class Viewer(widgets.DOMWidget):
    """
    Generic object for viewing and labeling Candidate objects in their rendered Contexts.

    Presents pages of rendered Contexts with their Candidates highlighted and
    persists the annotator's true/false labels both as ``GoldLabel`` rows
    (keyed by candidate) and as ``StableLabel`` rows (keyed by the candidates'
    stable context ids).
    """
    # Traits synced with the Javascript front-end view.
    _view_name = Unicode('ViewerView').tag(sync=True)
    _view_module = Unicode('viewer').tag(sync=True)
    cids = List().tag(sync=True)
    html = Unicode('<h3>Error!</h3>').tag(sync=True)
    _labels_serialized = Unicode().tag(sync=True)
    _selected_cid = Int().tag(sync=True)

    # NOTE: `gold` default changed from the mutable `[]` to the immutable `()`;
    # behavior is unchanged (it is only ever copied via list(gold)).
    def __init__(self, candidates, session, gold=(), n_per_page=3, height=225,
                 annotator_name=None):
        """
        Initializes a Viewer.

        The Viewer uses the keyword argument annotator_name to define a
        AnnotatorLabelKey with that name.

        :param candidates: A Python container of Candidates (e.g., not a
            CandidateSet, but candidate_set.candidates)
        :param session: The SnorkelSession for the database backend
        :param gold: Optional, Python container of Candidates that are known
            to have positive labels
        :param n_per_page: Optional, number of Contexts to display per page
        :param height: Optional, the height in pixels of the Viewer
        :param annotator_name: Name of the human using the Viewer, for saving
            their work. Defaults to system username.
        """
        super(Viewer, self).__init__()
        self.session = session

        # By default, use the username as annotator name
        name = annotator_name if annotator_name is not None else getpass.getuser()

        # Sets up the AnnotationKey to use; create-and-commit it if missing.
        self.annotator = self.session.query(GoldLabelKey).filter(
            GoldLabelKey.name == name).first()
        if self.annotator is None:
            self.annotator = GoldLabelKey(name=name)
            session.add(self.annotator)
            session.commit()

        # Viewer display configs
        self.n_per_page = n_per_page
        self.height = height

        # Note that the candidates are not necessarily committed to the DB, so
        # they *may not have* non-null ids; hence we index candidates by their
        # position in this list.
        # We get the sorted candidates and all contexts required, either from
        # unary or binary candidates.
        self.gold = list(gold)
        self.candidates = sorted(list(candidates),
                                 key=lambda c: c[0].char_start)
        self.contexts = list(
            set(c[0].get_parent() for c in self.candidates + self.gold))

        # If committed, sort contexts by id. Uncommitted contexts may have
        # id=None, making the sort fail; in that case keep the set order.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        try:
            self.contexts = sorted(self.contexts, key=lambda c: c.id)
        except Exception:
            pass

        # Loads existing annotations
        self.annotations = [None] * len(self.candidates)
        self.annotations_stable = [None] * len(self.candidates)
        init_labels_serialized = []
        for i, candidate in enumerate(self.candidates):

            # First look for the annotation in the primary annotations table
            existing_annotation = self.session.query(GoldLabel) \
                .filter(GoldLabel.key == self.annotator) \
                .filter(GoldLabel.candidate == candidate) \
                .first()
            if existing_annotation is not None:
                self.annotations[i] = existing_annotation
                if existing_annotation.value == 1:
                    value_string = 'true'
                elif existing_annotation.value == -1:
                    value_string = 'false'
                else:
                    raise ValueError(
                        str(existing_annotation) +
                        ' has value not in {1, -1}, which Viewer does not support.'
                    )
                init_labels_serialized.append(str(i) + '~~' + value_string)

                # If the annotator label is in the main table, also get its
                # stable version
                context_stable_ids = '~~'.join(
                    [c.stable_id for c in candidate.get_contexts()])
                existing_annotation_stable = self.session.query(StableLabel) \
                    .filter(StableLabel.context_stable_ids == context_stable_ids) \
                    .filter(StableLabel.annotator_name == name).one_or_none()

                # If stable version is not available, create it here
                # (context_stable_ids is already computed above).
                # NOTE: This is for versioning issues, should be removed?
                if existing_annotation_stable is None:
                    existing_annotation_stable = StableLabel(
                        context_stable_ids=context_stable_ids,
                        annotator_name=self.annotator.name,
                        split=candidate.split,
                        value=existing_annotation.value)
                    self.session.add(existing_annotation_stable)
                    self.session.commit()

                self.annotations_stable[i] = existing_annotation_stable

        self._labels_serialized = ','.join(init_labels_serialized)

        # Configures message handler
        self.on_msg(self.handle_label_event)

        # display js, construct html and pass on to widget model
        self.render()

    def _tag_span(self, html, cids, gold=False):
        """
        Create the span around a segment of the context associated with one or
        more candidates / gold annotations.

        :param html: Rendered inner HTML for the segment
        :param cids: Candidate indices (positions in self.candidates) covering
            this segment; they become CSS classes on the span
        :param gold: Whether this segment carries a gold annotation
        """
        classes = ['candidate'] if len(cids) > 0 else []
        classes += ['gold-annotation'] if gold else []
        classes += list(map(str, cids))

        # NOTE(review): an earlier comment promised scrubbing non-ascii
        # characters (replacing with '?'), but no scrub is performed here —
        # confirm whether it is still needed.
        return u'<span class="{classes}">{html}</span>'.format(
            classes=' '.join(classes), html=html)

    def _tag_context(self, context, candidates, gold):
        """Given the raw context, tag the spans using the generic _tag_span method"""
        raise NotImplementedError()

    def render(self):
        """Renders viewer pane"""
        cids = []

        # Iterate over pages of contexts
        pid = 0
        pages = []
        N = len(self.contexts)
        for i in range(0, N, self.n_per_page):
            page_cids = []
            lis = []
            for j in range(i, min(N, i + self.n_per_page)):
                context = self.contexts[j]

                # Get the candidates in this context
                candidates = [
                    c for c in self.candidates if c[0].get_parent() == context
                ]
                gold = [g for g in self.gold if g.context_id == context.id]

                # Construct the <li> and page view elements
                li_data = self._tag_context(context, candidates, gold)
                lis.append(LI_HTML.format(data=li_data, context_id=context.id))
                page_cids.append(
                    [self.candidates.index(c) for c in candidates])

            # Assemble the page; only the first page is initially visible.
            pages.append(
                PAGE_HTML.format(
                    pid=pid,
                    data=''.join(lis),
                    etc=' style="display: block;"' if i == 0 else ''))
            cids.append(page_cids)
            pid += 1

        # Render in primary Viewer template
        self.cids = cids
        self.html = open(HOME + '/viewer/viewer.html').read() % (
            self.height, ''.join(pages))
        display(Javascript(open(HOME + '/viewer/viewer.js').read()))

    def _get_labels(self):
        """
        De-serialize labels from Javascript widget, map to internal candidate
        id, and return as list of (cid, value) tuples with value in {1, -1, 0}.
        """
        LABEL_MAP = {'true': 1, 'false': -1}
        labels = [
            x.split('~~') for x in self._labels_serialized.split(',')
            if len(x) > 0
        ]
        vals = [(int(cid), LABEL_MAP.get(l, 0)) for cid, l in labels]
        return vals

    def handle_label_event(self, _, content, buffers):
        """
        Handles label event by persisting new label.

        ``content['event']`` is either 'set_label' (with 'cid' and a boolean
        'value') or 'delete_label' (with 'cid').
        """
        if content.get('event', '') == 'set_label':
            cid = content.get('cid', None)
            value = content.get('value', None)
            if value is True:
                value = 1
            elif value is False:
                value = -1
            else:
                raise ValueError('Unexpected label returned from widget: ' +
                                 str(value) +
                                 '. Expected values are True and False.')

            # If label already exists, just update value (in both
            # AnnotatorLabel and StableLabel)
            if self.annotations[cid] is not None:
                if self.annotations[cid].value != value:
                    self.annotations[cid].value = value
                    self.annotations_stable[cid].value = value
                    self.session.commit()

            # Otherwise, create a AnnotatorLabel *and a StableLabel*
            else:
                candidate = self.candidates[cid]

                # Create AnnotatorLabel
                self.annotations[cid] = GoldLabel(key=self.annotator,
                                                  candidate=candidate,
                                                  value=value)
                self.session.add(self.annotations[cid])

                # Create StableLabel
                context_stable_ids = '~~'.join(
                    [c.stable_id for c in candidate.get_contexts()])
                self.annotations_stable[cid] = StableLabel(
                    context_stable_ids=context_stable_ids,
                    annotator_name=self.annotator.name,
                    value=value,
                    split=candidate.split)
                self.session.add(self.annotations_stable[cid])
                self.session.commit()
        elif content.get('event', '') == 'delete_label':
            cid = content.get('cid', None)
            self.session.delete(self.annotations[cid])
            self.annotations[cid] = None
            self.session.delete(self.annotations_stable[cid])
            self.annotations_stable[cid] = None
            self.session.commit()

    def get_selected(self):
        """Return the Candidate currently selected in the front-end."""
        return self.candidates[self._selected_cid]
class Widget(LoggingConfigurable):
    """Base class for all Jupyter widgets synced with a front-end model over a Comm."""

    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    _widget_construction_callback = None
    _read_only_enabled = True
    # Registry of all live widgets, keyed by comm/model id.
    widgets = {}
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when a comm is opened from the front-end;
        instantiates the requested widget class bound to that comm."""
        widget_class = import_item(str(msg['content']['data']['widget_class']))
        widget = widget_class(comm=comm)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode(None, allow_none=True, help="""A requirejs module name
        in which to find _model_name. If empty, look in the global registry.""")
    _model_name = Unicode('WidgetModel', help="""Name of the backbone model
        registered in the front-end to create and sync this widget with.""")
    _view_module = Unicode(help="""A requirejs module in which to find _view_name.
        If empty, look in the global registry.""", sync=True)
    _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end
        to use to represent the widget.""", sync=True)
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the
        front-end can send before receiving an idle msg from the back-end.""")

    version = Int(0, sync=True, help="""Widget's version""")
    keys = List()

    def _keys_default(self):
        # By default, sync every trait tagged with sync=True.
        return [name for name in self.traits(sync=True)]

    _property_lock = Dict()
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------

    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            args = dict(target_name='ipython.widget',
                        data={'model_name': self._model_name,
                              'model_module': self._model_module})
            if self._model_id is not None:
                args['comm_id'] = self._model_id
            self.comm = Comm(**args)

    def _comm_changed(self, name, new):
        """Called when the comm is changed."""
        if new is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        Widget.widgets[self.model_id] = self

        # first update
        self.send_state()

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def __setattr__(self, name, value):
        """Overload of HasTraits.__setattr__ to handle read-only-ness of widget
        attributes"""
        if (self._read_only_enabled and self.has_trait(name) and
                self.trait_metadata(name, 'read_only')):
            raise TraitError('Widget attribute "%s" is read-only.' % name)
        else:
            super(Widget, self).__setattr__(name, value)

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        buffer_keys, buffers = [], []
        # FIX: iterate over a snapshot of the items. Popping from `state`
        # while iterating `state.items()` directly raises
        # "RuntimeError: dictionary changed size during iteration" on
        # Python 3 whenever a memoryview value is present.
        for k, v in list(state.items()):
            if isinstance(v, memoryview):
                # memoryviews are not JSON-serializable; ship them as
                # binary buffers alongside the state message instead.
                state.pop(k)
                buffers.append(v)
                buffer_keys.append(k)
        msg = {'method': 'update', 'state': state, 'buffers': buffer_keys}
        self._send(msg, buffers=buffers)

    def get_state(self, key=None):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict of states
        metadata : dict
            metadata for each field: {key: metadata}
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            state[k] = to_json(getattr(self, k), self)
        return state

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._allow_write(),\
             self._lock_property(**sync_data),\
             self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json',
                                                    self._trait_from_json)
                    setattr(self, name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed. In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def _allow_write(self):
        """Temporarily disable the read-only enforcement in __setattr__."""
        if self._read_only_enabled is False:
            yield
        else:
            try:
                self._read_only_enabled = False
                yield
            finally:
                self._read_only_enabled = True

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key in self._property_lock
                and to_json(value, self) == self._property_lock[key]):
            # This is the value the front-end just sent us; don't echo it.
            return False
        elif self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i, k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data)  # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _notify_trait(self, name, old_value, new_value):
        """Called when a property has been changed."""
        # Trigger default traitlet callback machinery. This allows any user
        # registered validation to be processed prior to allowing the widget
        # machinery to handle the state.
        LoggingConfigurable._notify_trait(self, name, old_value, new_value)

        # Send the state after the user registered callbacks for trait changes
        # have all fired (allows for user to validate values).
        if self.comm is not None and name in self.keys:
            # Make sure this isn't information that the front-end just sent us.
            if self._should_send_property(name, new_value):
                # Send new state to front-end
                self.send_state(key=name)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        # Show view.
        if self._view_name is not None:
            self._send({"method": "display"})
            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        self.comm.send(data=msg, buffers=buffers)
class CypherMagic(Magics, Configurable):
    """Runs Cypher statement on a database, specified by a connect string.

    Provides the %%cypher magic."""

    auto_limit = Int(defaults.auto_limit, config=True, help="""
        Automatically limit the size of the returned result sets
    """)
    style = Unicode(defaults.style, config=True, help="""
        Set the table printing style to any of prettytable's defined styles
        (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)
    """)
    short_errors = Bool(defaults.short_errors, config=True, help="""
        Don't display the full traceback on Neo4j errors
    """)
    data_contents = Bool(defaults.data_contents, config=True, help="""
        Bring extra data to render the results as a graph
    """)
    display_limit = Int(defaults.display_limit, config=True, help="""
        Automatically limit the number of rows displayed
        (full result set is still stored)
    """)
    auto_pandas = Bool(defaults.auto_pandas, config=True, help="""
        Return Pandas DataFrame instead of regular result sets
    """)
    auto_html = Bool(defaults.auto_html, config=True, help="""
        Return a D3 representation of the graph instead of regular result sets
    """)
    auto_networkx = Bool(defaults.auto_networkx, config=True, help="""
        Return Networkx MultiDiGraph instead of regular result sets
    """)
    rest = Bool(defaults.rest, config=True, help="""
        Return full REST representations of objects inside the result sets
    """)
    feedback = Bool(defaults.feedback, config=True, help="""
        Print number of rows affected
    """)
    uri = Unicode(defaults.uri, config=True, help="""
        Default database URL if none is defined inline
    """)

    def __init__(self, shell):
        Configurable.__init__(self, config=shell.config)
        Magics.__init__(self, shell=shell)

        # Add ourself to the list of module configurable via %config
        self.shell.configurables.append(self)
        self._legal_cypher_identifier = re.compile(r'^[A-Za-z0-9#_$]+')

    @needs_local_scope
    @line_magic('cypher')
    @cell_magic('cypher')
    def execute(self, line, cell='', local_ns=None):
        """Runs Cypher statement against a Neo4j graph database,
        specified by a connect string.

        If no database connection has been established, first word
        should be a connection string, or the user@host name
        of an established connection. Otherwise, http://localhost:7474/db/data
        will be assumed.

        Examples::

          %%cypher https://me:mypw@myhost:7474/db/data
          START n=node(*) RETURN n

          %%cypher me@myhost
          START n=node(*) RETURN n

          %%cypher
          START n=node(*) RETURN n

        Connect string syntax examples:

          http://localhost:7474/db/data
          https://me:mypw@localhost:7474/db/data
        """
        # FIX: default was the mutable `local_ns={}` — replaced with None
        # sentinel; behavior is unchanged for all callers.
        local_ns = {} if local_ns is None else local_ns
        # save globals and locals so they can be referenced in bind vars
        user_ns = self.shell.user_ns
        user_ns.update(local_ns)
        parsed = parse("""{0}\n{1}""".format(line, cell), self)
        conn = Connection.get(parsed['as'] or parsed['uri'])
        # 'persist <DataFrameName>' is a special command, not a Cypher query.
        first_word = parsed['cypher'].split(None, 1)[:1]
        if first_word and first_word[0].lower() == 'persist':
            return self._persist_dataframe(parsed['cypher'], conn, user_ns)
        try:
            result = run(parsed['cypher'], user_ns, self, conn)
            return result
        except StatusException as e:
            if self.short_errors:
                print(e)
            else:
                raise

    def _persist_dataframe(self, raw, conn, user_ns):
        """Resolve a DataFrame/Series by name in the user namespace and
        persist it to the connection's SQL engine as a table.

        :param raw: The raw 'persist <DataFrameName>' command text
        :param conn: The active Connection
        :param user_ns: The IPython user namespace to resolve the name in
        """
        if not DataFrame:
            raise ImportError("Must `pip install pandas` to use DataFrames")
        pieces = raw.split()
        if len(pieces) != 2:
            raise SyntaxError(
                "Format: %%cypher [connection] persist <DataFrameName>")
        frame_name = pieces[1].strip(';')
        # SECURITY NOTE(review): eval() on user-supplied text. This runs in an
        # interactive IPython session where the user already executes arbitrary
        # code, but consider user_ns[frame_name] lookup instead of eval.
        frame = eval(frame_name, user_ns)
        if not isinstance(frame, DataFrame) and not isinstance(frame, Series):
            raise TypeError('%s is not a Pandas DataFrame or Series' %
                            frame_name)
        # Derive a Cypher-legal table name from the frame name.
        table_name = frame_name.lower()
        table_name = self._legal_cypher_identifier.search(table_name).group(0)
        frame.to_sql(table_name, conn.session.engine)
        return 'Persisted %s' % table_name
class NamespacedResourceReflector(LoggingConfigurable):
    """
    Base class for keeping a local up-to-date copy of a set of kubernetes resources.

    Must be subclassed once per kind of resource that needs watching.

    A background thread keeps ``self.resources`` synchronized with the API
    server via a kubernetes watch; readers on other threads rely on single
    dict operations being atomic (see _watch_and_update).
    """
    labels = Dict({}, config=True, help="""
        Labels to reflect onto local cache
        """)

    fields = Dict({}, config=True, help="""
        Fields to restrict the reflected objects
        """)

    namespace = Unicode(None, allow_none=True, help="""
        Namespace to watch for resources in
        """)

    resources = Dict({}, help="""
        Dictionary of resource names to the appropriate resource objects.

        This can be accessed across threads safely.
        """)

    kind = Unicode('resource', help="""
        Human readable name for kind of object we're watching for.

        Used for diagnostic messages.
        """)

    list_method_name = Unicode("", help="""
        Name of function (on apigroup respresented by `api_group_name`) that is to be called to list resources.

        This will be passed a namespace & a label selector. You most likely want something
        of the form list_namespaced_<resource> - for example, `list_namespaced_pod` will
        give you a PodReflector.

        This must be set by a subclass.
        """)

    api_group_name = Unicode('CoreV1Api', help="""
        Name of class that represents the apigroup on which `list_method_name` is to be found.

        Defaults to CoreV1Api, which has everything in the 'core' API group. If you want to watch Ingresses,
        for example, you would have to use ExtensionsV1beta1Api
        """)

    request_timeout = Int(60, config=True, help="""
        Network timeout for kubernetes watch.

        Trigger watch reconnect when a given request is taking too long,
        which can indicate network issues.
        """)

    timeout_seconds = Int(10, config=True, help="""
        Timeout for kubernetes watch.

        Trigger watch reconnect when no watch event has been received.
        This will cause a full reload of the currently existing resources
        from the API server.
        """)

    on_failure = Any(
        help="""Function to be called when the reflector gives up.""")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Load kubernetes config here, since this is a Singleton and
        # so this __init__ will be run way before anything else gets run.
        try:
            config.load_incluster_config()
        except config.ConfigException:
            config.load_kube_config()
        self.api = shared_client(self.api_group_name)

        # FIXME: Protect against malicious labels?
        self.label_selector = ','.join(
            ['{}={}'.format(k, v) for k, v in self.labels.items()])
        self.field_selector = ','.join(
            ['{}={}'.format(k, v) for k, v in self.fields.items()])

        # Resolved (with None) once the first full listing has completed.
        self.first_load_future = Future()
        self._stop_event = threading.Event()

        # __init__ starts the watch thread immediately.
        self.start()

    def __del__(self):
        self.stop()

    def _list_and_update(self):
        """
        Update current list of resources by doing a full fetch.

        Overwrites all current resource info.

        Returns the listing's resource_version so a watch can resume from it.
        """
        initial_resources = getattr(self.api, self.list_method_name)(
            self.namespace,
            label_selector=self.label_selector,
            field_selector=self.field_selector,
            _request_timeout=self.request_timeout,
        )
        # This is an atomic operation on the dictionary!
        self.resources = {p.metadata.name: p for p in initial_resources.items}
        # return the resource version so we can hook up a watch
        return initial_resources.metadata.resource_version

    def _watch_and_update(self):
        """
        Keeps the current list of resources up-to-date

        This method is to be run not on the main thread!

        We first fetch the list of current resources, and store that. Then we
        register to be notified of changes to those resources, and keep our
        local store up-to-date based on these notifications.

        We also perform exponential backoff, giving up after we hit 32s
        wait time. This should protect against network connections dropping
        and intermittent unavailability of the api-server. Every time we
        recover from an exception we also do a full fetch, to pick up
        changes that might've been missed in the time we were not doing
        a watch.

        Note that we're playing a bit with fire here, by updating a dictionary
        in this thread while it is probably being read in another thread
        without using locks! However, dictionary access itself is atomic,
        and as long as we don't try to mutate them (do a 'fetch / modify /
        update' cycle on them), we should be ok!
        """
        selectors = []
        log_name = ""
        if self.label_selector:
            selectors.append("label selector=%r" % self.label_selector)
        if self.field_selector:
            selectors.append("field selector=%r" % self.field_selector)
        log_selector = ', '.join(selectors)

        # Backoff delay, doubled on each consecutive failure.
        cur_delay = 0.1
        self.log.info(
            "watching for %s with %s in namespace %s",
            self.kind,
            log_selector,
            self.namespace,
        )
        while True:
            self.log.debug("Connecting %s watcher", self.kind)
            w = watch.Watch()
            try:
                # Full re-list first; its resource_version anchors the watch.
                resource_version = self._list_and_update()
                if not self.first_load_future.done():
                    # signal that we've loaded our initial data
                    self.first_load_future.set_result(None)
                watch_args = {
                    'namespace': self.namespace,
                    'label_selector': self.label_selector,
                    'field_selector': self.field_selector,
                    'resource_version': resource_version,
                }
                if self.request_timeout:
                    # set network receive timeout
                    watch_args['_request_timeout'] = self.request_timeout
                if self.timeout_seconds:
                    # set watch timeout
                    watch_args['timeout_seconds'] = self.timeout_seconds
                # in case of timeout_seconds, the w.stream just exits (no exception thrown)
                # -> we stop the watcher and start a new one
                for ev in w.stream(
                        getattr(self.api, self.list_method_name),
                        **watch_args):
                    # A successful event resets the backoff delay.
                    cur_delay = 0.1
                    resource = ev['object']
                    if ev['type'] == 'DELETED':
                        # This is an atomic delete operation on the dictionary!
                        self.resources.pop(resource.metadata.name, None)
                    else:
                        # This is an atomic operation on the dictionary!
                        self.resources[resource.metadata.name] = resource
                    if self._stop_event.is_set():
                        self.log.info("%s watcher stopped", self.kind)
                        break
            except ReadTimeoutError:
                # network read time out, just continue and restart the watch
                # this could be due to a network problem or just low activity
                self.log.warning("Read timeout watching %s, reconnecting",
                                 self.kind)
                continue
            except Exception:
                # Deliberately broad: any failure triggers backoff + full
                # re-list on the next loop iteration.
                cur_delay = cur_delay * 2
                if cur_delay > 30:
                    self.log.exception(
                        "Watching resources never recovered, giving up")
                    if self.on_failure:
                        self.on_failure()
                    return
                self.log.exception(
                    "Error when watching resources, retrying in %ss",
                    cur_delay)
                time.sleep(cur_delay)
                continue
            else:
                # no events on watch, reconnect
                self.log.debug("%s watcher timeout", self.kind)
            finally:
                w.stop()
                if self._stop_event.is_set():
                    self.log.info("%s watcher stopped", self.kind)
                    break
        self.log.warning("%s watcher finished", self.kind)

    def start(self):
        """
        Start the reflection process!

        We'll do a blocking read of all resources first, so that we don't
        race with any operations that are checking the state of the pod store
        - such as polls. This should be called only once at the start of
        program initialization (when the singleton is being created), and
        not afterwards!
        """
        if hasattr(self, 'watch_thread'):
            raise ValueError(
                'Thread watching for resources is already running')
        self._list_and_update()
        self.watch_thread = threading.Thread(target=self._watch_and_update)
        # If the watch_thread is only thread left alive, exit app
        self.watch_thread.daemon = True
        self.watch_thread.start()

    def stop(self):
        # Signal the watch thread to exit at its next stop-event check.
        self._stop_event.set()

    def stopped(self):
        # True once stop() has been requested.
        return self._stop_event.is_set()
class _Selection(DescriptionWidget, ValueWidget, CoreWidget):
    """Base class for Selection widgets

    ``options`` can be specified as a list of values or a list of (label, value)
    tuples. The labels are the strings that will be displayed in the UI,
    representing the actual Python choices, and should be unique.
    If labels are not specified, they are generated from the values.

    When programmatically setting the value, a reverse lookup is performed
    among the options to check that the value is valid. The reverse lookup uses
    the equality operator by default, but another predicate may be provided via
    the ``equals`` keyword argument. For example, when dealing with numpy arrays,
    one may set equals=np.array_equal.
    """

    value = Any(None, help="Selected value", allow_none=True)
    label = Unicode(None, help="Selected label", allow_none=True)
    index = Int(None, help="Selected index", allow_none=True).tag(sync=True)

    options = Any((),
                  help="""Iterable of values or (label, value) pairs that the user can select.

    The labels are the strings that will be displayed in the UI, representing the
    actual Python choices, and should be unique.
    """)

    # Full (label, value) form of `options`, kept by the validators.
    _options_full = None

    # This being read-only means that it cannot be changed by the user.
    _options_labels = TypedTuple(
        trait=Unicode(),
        read_only=True,
        help="The labels for the options.").tag(sync=True)

    disabled = Bool(help="Enable or disable user changes").tag(sync=True)

    def __init__(self, *args, **kwargs):
        # Custom equality predicate for the value -> option reverse lookup.
        self.equals = kwargs.pop('equals', lambda x, y: x == y)

        # We have to make the basic options bookkeeping consistent
        # so we don't have errors the first time validators run
        self._initializing_traits_ = True
        kwargs['options'] = _exhaust_iterable(kwargs.get('options', ()))
        self._options_full = _make_options(kwargs['options'])
        # Called manually (change=None) to populate labels/values before
        # the traitlets machinery fires.
        self._propagate_options(None)

        # Select the first item by default, if we can
        if 'index' not in kwargs and 'value' not in kwargs and 'label' not in kwargs:
            options = self._options_full
            nonempty = (len(options) > 0)
            kwargs['index'] = 0 if nonempty else None
            kwargs['label'], kwargs['value'] = options[0] if nonempty else (
                None, None)

        super().__init__(*args, **kwargs)
        self._initializing_traits_ = False

    @validate('options')
    def _validate_options(self, proposal):
        # if an iterator is provided, exhaust it
        proposal.value = _exhaust_iterable(proposal.value)
        # throws an error if there is a problem converting to full form
        self._options_full = _make_options(proposal.value)
        return proposal.value

    @observe('options')
    def _propagate_options(self, change):
        "Set the values and labels, and select the first option if we aren't initializing"
        options = self._options_full
        # _options_labels is read_only, so it must be set via set_trait.
        self.set_trait('_options_labels', tuple(i[0] for i in options))
        self._options_values = tuple(i[1] for i in options)
        if self._initializing_traits_ is not True:
            if len(options) > 0:
                if self.index == 0:
                    # Explicitly trigger the observers to pick up the new value and
                    # label. Just setting the value would not trigger the observers
                    # since traitlets thinks the value hasn't changed.
                    self._notify_trait('index', 0, 0)
                else:
                    self.index = 0
            else:
                self.index = None

    @validate('index')
    def _validate_index(self, proposal):
        # None (no selection) is always valid; otherwise the index must be
        # within the current label range.
        if proposal.value is None or 0 <= proposal.value < len(
                self._options_labels):
            return proposal.value
        else:
            raise TraitError('Invalid selection: index out of bounds')

    @observe('index')
    def _propagate_index(self, change):
        "Propagate changes in index to the value and label properties"
        label = self._options_labels[
            change.new] if change.new is not None else None
        value = self._options_values[
            change.new] if change.new is not None else None
        # NOTE: identity ('is not') comparison, not equality — avoids invoking
        # (possibly expensive or ambiguous) __eq__ on arbitrary option values.
        if self.label is not label:
            self.label = label
        if self.value is not value:
            self.value = value

    @validate('value')
    def _validate_value(self, proposal):
        value = proposal.value
        try:
            # Reverse lookup among option values, using self.equals.
            return findvalue(self._options_values, value,
                             self.equals) if value is not None else None
        except ValueError:
            raise TraitError('Invalid selection: value not found')

    @observe('value')
    def _propagate_value(self, change):
        # Keep `index` in sync with a programmatic change of `value`.
        if change.new is None:
            index = None
        elif self.index is not None and self._options_values[
                self.index] == change.new:
            # Current index already points at this value; keep it.
            index = self.index
        else:
            index = self._options_values.index(change.new)
        if self.index != index:
            self.index = index

    @validate('label')
    def _validate_label(self, proposal):
        if (proposal.value is not None) and (
                proposal.value not in self._options_labels):
            raise TraitError('Invalid selection: label not found')
        return proposal.value

    @observe('label')
    def _propagate_label(self, change):
        # Keep `index` in sync with a programmatic change of `label`.
        if change.new is None:
            index = None
        elif self.index is not None and self._options_labels[
                self.index] == change.new:
            index = self.index
        else:
            index = self._options_labels.index(change.new)
        if self.index != index:
            self.index = index

    def _repr_keys(self):
        keys = super()._repr_keys()
        # Include options manually, as it isn't marked as synced:
        for key in sorted(chain(keys, ('options', ))):
            if key == 'index' and self.index == 0:
                # Index 0 is default when there are options
                continue
            yield key
class Builder(widgets.DOMWidget):
    """A Python wrapper for the Escher metabolic map.

    This map will also show data on reactions, metabolites, or genes.

    The Builder is a Jupyter widget that can be viewed in a Jupyter notebook
    or in Jupyter Lab. It can also be used to create a standalone HTML file
    for the map with the save_html() function. Maps are downloaded from the
    Escher website if found by name.

    :param int height: The height of the Escher Jupyter widget in pixels.

    :param str map_name: A string specifying a map to be downloaded from the
        Escher website.

    :param str map_json: A JSON string, or a file path to a JSON file, or a
        URL specifying a JSON file to be downloaded.

    :param model: A COBRApy model.

    :param model_name: A string specifying a model to be downloaded from the
        Escher web server.

    :param model_json: A JSON string, or a file path to a JSON file, or a URL
        specifying a JSON file to be downloaded.

    :param embedded_css: The CSS (as a string) to be embedded with the Escher
        SVG. In Jupyter, if you change embedded_css on an existing builder
        instance, the Builder must be restarted for this to take effect (e.g.
        by re-evaluating the widget in a cell). You can use the default
        embedded css as a starting point:

        https://github.com/zakandrewking/escher/blob/master/src/Builder-embed.css

    :param reaction_data: A dictionary with keys that correspond to reaction
        IDs and values that will be mapped to reaction arrows and labels.

    :param metabolite_data: A dictionary with keys that correspond to
        metabolite IDs and values that will be mapped to metabolite nodes and
        labels.

    :param gene_data: A dictionary with keys that correspond to gene IDs and
        values that will be mapped to corresponding reactions.

    **Keyword Arguments**

    You can also pass in any of the following options as keyword arguments.
    The details on each of these are provided in the JavaScript API
    documentation:

    - use_3d_transform
    - menu
    - scroll_behavior
    - enable_editing
    - enable_keys
    - enable_search
    - zoom_to_element
    - full_screen_button
    - disabled_buttons
    - semantic_zoom
    - starting_reaction
    - never_ask_before_quit
    - primary_metabolite_radius
    - secondary_metabolite_radius
    - marker_radius
    - gene_font_size
    - hide_secondary_metabolites
    - show_gene_reaction_rules
    - hide_all_labels
    - canvas_size_and_loc
    - reaction_data
    - reaction_styles
    - reaction_opacity
    - reaction_compare_style
    - reaction_scale
    - reaction_no_data_color
    - reaction_no_data_size
    - reaction_highlight
    - gene_data
    - and_method_in_gene_reaction_rule
    - metabolite_data
    - metabolite_styles
    - metabolite_compare_style
    - metabolite_scale
    - metabolite_no_data_color
    - metabolite_no_data_size
    - identifiers_on_map
    - highlight_missing
    - allow_building_duplicate_reactions
    - cofactors
    - enable_tooltips
    - enable_keys_with_tooltip
    - reaction_scale_preset
    - metabolite_scale_preset

    If any of these is set to None, the default (or most-recent) value is
    used. To turn off a setting, use False instead.

    All arguments can also be set by assigning the property of an existing
    Builder object, e.g.:

    .. code:: python

        my_builder.map_name = 'iJO1366.Central metabolism'

    """

    # widget info traitlets
    _view_name = Unicode('EscherMapView').tag(sync=True)
    _model_name = Unicode('EscherMapModel').tag(sync=True)
    _view_module = Unicode('escher').tag(sync=True)
    _model_module = Unicode('escher').tag(sync=True)
    _view_module_version = Unicode(__version__).tag(sync=True)
    _model_module_version = Unicode(__version__).tag(sync=True)

    # Python package options
    height = Int(500).tag(sync=True)
    embedded_css = Unicode(None, allow_none=True).tag(sync=True)

    @validate('embedded_css')
    def _validate_embedded_css(self, proposal):
        """Collapse embedded CSS to a single line (or None when falsy)."""
        css = proposal['value']
        if css:
            return css.replace('\n', '')
        else:
            return None

    # synced data
    _loaded_map_json = Unicode(None, allow_none=True).tag(sync=True)

    @observe('_loaded_map_json')
    def _observe_loaded_map_json(self, change):
        # if map is cleared, then clear these
        if not change.new:
            self.map_name = None
            self.map_json = None

    _loaded_model_json = Unicode(None, allow_none=True).tag(sync=True)

    @observe('_loaded_model_json')
    def _observe_loaded_model_json(self, change):
        # if model is cleared, then clear these
        if not change.new:
            self.model = None
            self.model_name = None
            self.model_json = None

    # Python options that are indirectly synced to the widget
    map_name = Unicode(None, allow_none=True)

    @observe('map_name')
    def _observe_map_name(self, change):
        # Download the named map (or clear the loaded map).
        if change.new:
            self._loaded_map_json = map_json_for_name(change.new)
        else:
            self._loaded_map_json = None

    map_json = Unicode(None, allow_none=True)

    @observe('map_json')
    def _observe_map_json(self, change):
        # Load map JSON from a string, file path, or URL.
        if change.new:
            self._loaded_map_json = _load_resource(change.new, 'map_json')
        else:
            self._loaded_map_json = None

    model = Instance(Model, allow_none=True)

    @observe('model')
    def _observe_model(self, change):
        # Serialize the COBRApy model to JSON for the front end.
        if change.new:
            self._loaded_model_json = cobra.io.to_json(change.new)
        else:
            self._loaded_model_json = None

    model_name = Unicode(None, allow_none=True)

    @observe('model_name')
    def _observe_model_name(self, change):
        # Download the named model (or clear the loaded model).
        if change.new:
            self._loaded_model_json = model_json_for_name(change.new)
        else:
            self._loaded_model_json = None

    model_json = Unicode(None, allow_none=True)

    @observe('model_json')
    def _observe_model_json(self, change):
        # Load model JSON from a string, file path, or URL.
        if change.new:
            self._loaded_model_json = _load_resource(change.new, 'model_json')
        else:
            self._loaded_model_json = None

    # Synced options passed as an object to JavaScript Builder. Traits tagged
    # option=True are collected by save_html() and forwarded to the JS API.
    menu = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    scroll_behavior = Any('none', allow_none=False)\
        .tag(sync=True, option=True)
    use_3d_transform = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    enable_editing = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    enable_keys = Any(False, allow_none=True)\
        .tag(sync=True, option=True)
    enable_search = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    zoom_to_element = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    full_screen_button_default = {
        'enable_keys': True,
        'scroll_behavior': 'pan',
        'enable_editing': True,
        'menu': 'all',
        'enable_tooltips': ['label']
    }
    full_screen_button = Any(full_screen_button_default, allow_none=True)\
        .tag(sync=True, option=True)
    disabled_buttons = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    semantic_zoom = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    starting_reaction = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    never_ask_before_quit = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    primary_metabolite_radius = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    secondary_metabolite_radius = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    marker_radius = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    gene_font_size = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    hide_secondary_metabolites = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    show_gene_reaction_rules = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    hide_all_labels = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    canvas_size_and_loc = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_data = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    @validate('reaction_data')
    def _validate_reaction_data(self, proposal):
        """Convert incoming reaction data to the canonical dict form."""
        try:
            return convert_data(proposal['value'])
        except Exception:
            raise Exception("""Invalid data for reaction_data. Must be pandas Series, pandas DataFrame, dict, list, or None""")

    reaction_styles = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_opacity = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_highlight = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_compare_style = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_scale = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_no_data_color = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_no_data_size = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    gene_data = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    @validate('gene_data')
    def _validate_gene_data(self, proposal):
        """Convert incoming gene data to the canonical dict form."""
        try:
            return convert_data(proposal['value'])
        except Exception:
            raise Exception("""Invalid data for gene_data. Must be pandas Series, pandas DataFrame, dict, list, or None""")

    and_method_in_gene_reaction_rule = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_data = Any(None, allow_none=True)\
        .tag(sync=True, option=True)

    @validate('metabolite_data')
    def _validate_metabolite_data(self, proposal):
        """Convert incoming metabolite data to the canonical dict form."""
        try:
            return convert_data(proposal['value'])
        except Exception:
            raise Exception("""Invalid data for metabolite_data. Must be pandas Series, pandas DataFrame, dict, list, or None""")

    metabolite_styles = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_compare_style = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_scale = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_no_data_color = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_no_data_size = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    identifiers_on_map = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    highlight_missing = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    allow_building_duplicate_reactions = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    cofactors = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    enable_tooltips = Any(False, allow_none=True)\
        .tag(sync=True, option=True)
    enable_keys_with_tooltip = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    reaction_scale_preset = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    metabolite_scale_preset = Any(None, allow_none=True)\
        .tag(sync=True, option=True)
    # NOTE: the original source re-declared primary_metabolite_radius,
    # secondary_metabolite_radius, marker_radius, gene_font_size,
    # reaction_no_data_size, and metabolite_no_data_size here with identical
    # definitions; the duplicates have been removed (they were shadowed
    # re-definitions with no behavioral effect).

    def __init__(
            self,
            map_name: str = None,
            map_json: str = None,
            model: Model = None,
            model_name: str = None,
            model_json: str = None,
            **kwargs,
    ) -> None:
        """Initialize the Builder, resolving competing map/model sources.

        Explicit JSON (map_json/model_json) or a live model take precedence
        over names; a warning is emitted for each overridden argument and for
        each unsupported keyword option.
        """
        # kwargs will instantiate the traitlets
        super().__init__(**kwargs)

        if map_json:
            if map_name:
                warn('map_json overrides map_name')
            self.map_json = map_json
        else:
            self.map_name = map_name

        if model:
            if model_name:
                warn('model overrides model_name')
            if model_json:
                warn('model overrides model_json')
            self.model = model
        elif model_json:
            if model_name:
                warn('model_json overrides model_name')
            self.model_json = model_json
        else:
            self.model_name = model_name

        # Options from the JavaScript API that cannot be set via Python.
        unavailable_options = {
            'fill_screen': """The fill_option screen is set automatically by the Escher Python package""",
            'tooltip_component': """The tooltip_component cannot be customized with the Python API""",
            'first_load_callback': """The first_load_callback cannot be customized with the Python API""",
            'unique_map_id': """The option unique_map_id is deprecated""",
            # BUG FIX: this message previously referred to unique_map_id.
            'ignore_bootstrap': """The option ignore_bootstrap is deprecated""",
        }

        # BUG FIX: warn with the explanatory message for the unsupported
        # option, not with the arbitrary value the caller passed for it.
        for key in kwargs:
            if key in unavailable_options:
                warn(unavailable_options[key])

    def display_in_notebook(self, *args, **kwargs):
        """Deprecated. The Builder is now a Jupyter Widget, so you can return
        the Builder object from a cell to display it, or you can manually call
        the IPython display function:

        from IPython.display import display
        from escher import Builder
        b = Builder(...)
        display(b)

        """
        raise Exception(('display_in_notebook is deprecated. The Builder is '
                         'now a Jupyter Widget, so you can return the '
                         'Builder in a cell to see it, or use the IPython '
                         'display function (see Escher docs for details)'))

    def display_in_browser(self, *args, **kwargs):
        """Deprecated. We recommend using the Jupyter Widget (which now
        supports all Escher features) or the save_html option to generate a
        standalone HTML file that loads the map.

        """
        raise Exception(('display_in_browser is deprecated. We recommend using'
                         'the Jupyter Widget (which now supports all Escher'
                         'features) or the save_html option to generate a'
                         'standalone HTML file that loads the map.'))

    def save_html(self, filepath):
        """Save an HTML file containing the map.

        :param string filepath: The name of the HTML file.

        """
        # Collect every non-None trait tagged option=True for the JS Builder.
        options = {}
        for key in self.traits(option=True):
            val = getattr(self, key)
            if val is not None:
                options[key] = val
        options_json = json.dumps(options)

        template = env.get_template('standalone.html')
        embedded_css_b64 = (b64dump(self.embedded_css)
                            if self.embedded_css is not None else None)
        html = template.render(
            escher_url=get_url('escher_min'),
            embedded_css_b64=embedded_css_b64,
            map_data_json_b64=b64dump(self._loaded_map_json),
            model_data_json_b64=b64dump(self._loaded_model_json),
            options_json_b64=b64dump(options_json),
        )

        with open(expanduser(filepath), 'wb') as f:
            f.write(html.encode('utf-8'))
class TrajectoryPlayer(HasTraits): # should set default values here different from desired defaults # so `observe` can be triggered step = Int(0) sync_frame = Bool(True) interpolate = Bool(False) delay = Float(0.0) parameters = Dict() iparams = Dict() _interpolation_t = Float() _iterpolation_type = CaselessStrEnum(['linear', 'spline']) spin = Bool(False) _spin_x = Int(1) _spin_y = Int(0) _spin_z = Int(0) _spin_speed = Float(0.005) camera = CaselessStrEnum(['perspective', 'orthographic'], default_value='perspective') _render_params = Dict() _real_time_update = Bool(False) widget_tab = Any(None) widget_repr = Any(None) widget_repr_parameters = Any(None) widget_quick_repr = Any(None) widget_general = Any(None) widget_picked = Any(None) widget_preference = Any(None) widget_extra = Any(None) widget_theme = Any(None) widget_help = Any(None) widget_export_image = Any(None) widget_component_slider = Any(None) widget_repr_slider = Any(None) widget_repr_choices = Any(None) widget_repr_control_buttons = Any(None) widget_repr_add = Any(None) widget_accordion_repr_parameters = Any(None) widget_repr_parameters_dialog = Any(None) widget_repr_name = Any(None) widget_component_dropdown = Any(None) widget_drag = Any(None) def __init__(self, view, step=1, delay=100, sync_frame=False, min_delay=40): self._view = view self.step = step self.sync_frame = sync_frame self.delay = delay self.min_delay = min_delay self._interpolation_t = 0.5 self._iterpolation_type = 'linear' self.iparams = dict(t=self._interpolation_t, step=1, type=self._iterpolation_type) self._render_params = dict(factor=4, antialias=True, trim=False, transparent=False) self._widget_names = [w for w in dir(self) if w.startswith('wiget_')] def _update_padding(self, padding=default.DEFAULT_PADDING): widget_collection = [ self.widget_general, self.widget_repr, self.widget_preference, self.widget_repr_parameters, self.widget_help, self.widget_extra, self.widget_picked ] for widget in widget_collection: if widget is not None: 
widget.layout.padding = padding def _create_all_widgets(self): if self.widget_tab is None: self.widget_tab = self._display() old_index = self.widget_tab.selected_index for index, _ in enumerate(self.widget_tab.children): self.widget_tab.selected_index = index self.widget_tab.selected_index = old_index def smooth(self): self.interpolate = True @observe('camera') def on_camera_changed(self, change): camera_type = change['new'] self._view._remote_call("setParameters", target='Stage', kwargs=dict(cameraType=camera_type)) @property def frame(self): return self._view.frame @frame.setter def frame(self, value): self._view.frame = value @property def count(self): return self._view.count @observe('sync_frame') def update_sync_frame(self, change): value = change['new'] if value: self._view._set_sync_frame() else: self._view._set_unsync_frame() @observe("delay") def update_delay(self, change): delay = change['new'] self._view._set_delay(delay) @observe('parameters') def update_parameters(self, change): params = change['new'] self.sync_frame = params.get("sync_frame", self.sync_frame) self.delay = params.get("delay", self.delay) self.step = params.get("step", self.step) @observe('_interpolation_t') def _interpolation_t_changed(self, change): self.iparams['t'] = change['new'] @observe('spin') def on_spin_changed(self, change): self.spin = change['new'] if self.spin: self._view._set_spin([self._spin_x, self._spin_y, self._spin_z], self._spin_speed) else: # stop self._view._set_spin(None, None) @observe('_spin_x') def on_spin_x_changed(self, change): self._spin_x = change['new'] if self.spin: self._view._set_spin([self._spin_x, self._spin_y, self._spin_z], self._spin_speed) @observe('_spin_y') def on_spin_y_changed(self, change): self._spin_y = change['new'] if self.spin: self._view._set_spin([self._spin_x, self._spin_y, self._spin_z], self._spin_speed) @observe('_spin_z') def on_spin_z_changed(self, change): self._spin_z = change['new'] if self.spin: 
self._view._set_spin([self._spin_x, self._spin_y, self._spin_z], self._spin_speed) @observe('_spin_speed') def on_spin_speed_changed(self, change): self._spin_speed = change['new'] if self.spin: self._view._set_spin([self._spin_x, self._spin_y, self._spin_z], self._spin_speed) def _display(self): box_factory = [(self._make_general_box, 'General'), (self._make_widget_repr, 'Representation'), (self._make_widget_preference, 'Preference'), (self._make_theme_box, 'Theme'), (self._make_extra_box, 'Extra'), (self._show_website, 'Help')] tab = _make_delay_tab(box_factory, selected_index=-1) # tab = _make_autofit(tab) tab.layout.align_self = 'center' tab.layout.align_items = 'stretch' self.widget_tab = tab return self.widget_tab def _make_widget_tab(self): return self._display() def _make_button_center(self): button = Button(description=' Center', icon='fa-bullseye') @button.on_click def on_click(button): self._view.center() return button def _make_button_theme(self): button = Button(description='Oceans16') @button.on_click def on_click(button): from nglview import theme display(theme.oceans16()) self._view._remote_call('cleanOutput', target='Widget') return button def _make_button_reset_theme(self, hide_toolbar=False): from nglview import theme if hide_toolbar: button = Button(description='Simplified Default') @button.on_click def on_click(button): theme.reset(hide_toolbar=True) else: button = Button(description='Default') @button.on_click def on_click(button): theme.reset() return button def _make_button_clean_error_output(self): button = Button(description='Clear Error') @button.on_click def on_click(_): js_utils.clean_error_output() return button def _make_widget_preference(self, width='100%'): def make_func(): parameters = self._view._full_stage_parameters def func(pan_speed=parameters.get('panSpeed', 0.8), rotate_speed=parameters.get('rotateSpeed', 2), zoom_speed=parameters.get('zoomSpeed', 1.2), clip_dist=parameters.get('clipDist', 10), 
camera_fov=parameters.get('cameraFov', 40), clip_far=parameters.get('clipFar', 100), clip_near=parameters.get('clipNear', 0), fog_far=parameters.get('fogFar', 100), fog_near=parameters.get('fogNear', 50), impostor=parameters.get('impostor', True), light_intensity=parameters.get('lightIntensity', 1), quality=parameters.get('quality', 'medium'), sample_level=parameters.get('sampleLevel', 1)): self._view.parameters = dict(panSpeed=pan_speed, rotateSpeed=rotate_speed, zoomSpeed=zoom_speed, clipDist=clip_dist, clipFar=clip_far, clipNear=clip_near, cameraFov=camera_fov, fogFar=fog_far, fogNear=fog_near, impostor=impostor, lightIntensity=light_intensity, quality=quality, sampleLevel=sample_level) return func def make_widget_box(): widget_sliders = interactive(make_func(), pan_speed=(0, 10, 0.1), rotate_speed=(0, 10, 1), zoom_speed=(0, 10, 1), clip_dist=(0, 200, 5), clip_far=(0, 100, 1), clip_near=(0, 100, 1), camera_fov=(15, 120, 1), fog_far=(0, 100, 1), fog_near=(0, 100, 1), light_intensity=(0, 10, 0.02), quality=['low', 'medium', 'high'], sample_level=(-1, 5, 1)) for child in widget_sliders.children: if isinstance(child, (IntSlider, FloatSlider)): child.layout.width = default.DEFAULT_SLIDER_WIDTH return widget_sliders if self.widget_preference is None: widget_sliders = make_widget_box() reset_button = Button(description='Reset') widget_sliders.children = [ reset_button, ] + list(widget_sliders.children) @reset_button.on_click def on_click(reset_button): self._view.parameters = self._view._original_stage_parameters self._view._full_stage_parameters = self._view._original_stage_parameters widget_sliders.children = [ reset_button, ] + list(make_widget_box().children) self.widget_preference = _relayout_master(widget_sliders, width=width) return self.widget_preference def _show_download_image(self): # "interactive" does not work for True/False in ipywidgets 4 yet. 
button = Button(description=' Screenshot', icon='fa-camera') @button.on_click def on_click(button): self._view.download_image() return button def _make_button_url(self, url, description): button = Button(description=description) @button.on_click def on_click(button): display(Javascript(js_utils.open_url_template.format(url=url))) return button def _show_website(self, ngl_base_url=default.NGL_BASE_URL): buttons = [ self._make_button_url(url.format(ngl_base_url), description) for url, description in [("'http://arose.github.io/nglview/latest/'", "nglview"), ( "'{}/index.html'", "NGL"), ("'{}/tutorial-selection-language.html'", "Selection"), ("'{}/tutorial-molecular-representations.html'", "Representation")] ] self.widget_help = _make_autofit(HBox(buttons)) return self.widget_help def _make_button_qtconsole(self): from nglview import js_utils button = Button(description='qtconsole', tooltip='pop up qtconsole') @button.on_click def on_click(button): js_utils.launch_qtconsole() return button def _make_text_picked(self): ta = Textarea(value=json.dumps(self._view.picked), description='Picked atom') ta.layout.width = '300px' return ta def _refresh(self, component_slider, repr_slider): """update representation and component information """ self._view._request_repr_parameters(component=component_slider.value, repr_index=repr_slider.value) self._view._remote_call('requestReprInfo', target='Widget') self._view._handle_repr_dict_changed(change=dict( new=self._view._repr_dict)) def _make_button_repr_control(self, component_slider, repr_slider, repr_selection): button_refresh = Button(description=' Refresh', tooltip='Get representation info', icon='fa-refresh') button_center_selection = Button(description=' Center', tooltip='center selected atoms', icon='fa-bullseye') button_center_selection._ngl_name = 'button_center_selection' button_hide = Button(description=' Hide', icon='fa-eye-slash', tooltip='Hide/Show current representation') button_remove = Button(description=' Remove', 
icon='fa-trash', tooltip='Remove current representation') button_repr_parameter_dialog = Button( description=' Dialog', tooltip='Pop up representation parameters control dialog') @button_refresh.on_click def on_click_refresh(button): self._refresh(component_slider, repr_slider) @button_center_selection.on_click def on_click_center(center_selection): self._view.center_view(selection=repr_selection.value, component=component_slider.value) @button_hide.on_click def on_click_hide(button_hide): component = component_slider.value repr_index = repr_slider.value if button_hide.description == 'Hide': hide = True button_hide.description = 'Show' else: hide = False button_hide.description = 'Hide' self._view._remote_call('setVisibilityForRepr', target='Widget', args=[component, repr_index, not hide]) @button_remove.on_click def on_click_remove(button_remove): self._view._remove_representation(component=component_slider.value, repr_index=repr_slider.value) self._view._request_repr_parameters( component=component_slider.value, repr_index=repr_slider.value) @button_repr_parameter_dialog.on_click def on_click_repr_dialog(_): from nglview.widget_box import DraggableBox if self.widget_repr_parameters is not None and self.widget_repr_choices: self.widget_repr_parameters_dialog = DraggableBox( [self.widget_repr_choices, self.widget_repr_parameters]) self.widget_repr_parameters_dialog._ipython_display_() self.widget_repr_parameters_dialog._dialog = 'on' bbox = _make_autofit( HBox([ button_refresh, button_center_selection, button_hide, button_remove, button_repr_parameter_dialog ])) return bbox def _make_widget_repr(self): self.widget_repr_name = Text(value='', description='representation') self.widget_repr_name._ngl_name = 'repr_name_text' repr_selection = Text(value=' ', description='selection') repr_selection._ngl_name = 'repr_selection' repr_selection.width = self.widget_repr_name.width = default.DEFAULT_TEXT_WIDTH max_n_components = max(self._view.n_components - 1, 0) 
self.widget_component_slider = IntSlider(value=0, max=max_n_components, min=0, description='component') self.widget_component_slider._ngl_name = 'component_slider' cvalue = ' ' self.widget_component_dropdown = Dropdown(value=cvalue, options=[ cvalue, ], description='component') self.widget_component_dropdown._ngl_name = 'component_dropdown' self.widget_repr_slider = IntSlider(value=0, description='representation', width=default.DEFAULT_SLIDER_WIDTH) self.widget_repr_slider._ngl_name = 'repr_slider' self.widget_repr_slider.visible = True self.widget_component_slider.layout.width = default.DEFAULT_SLIDER_WIDTH self.widget_repr_slider.layout.width = default.DEFAULT_SLIDER_WIDTH self.widget_component_dropdown.layout.width = self.widget_component_dropdown.max_width = default.DEFAULT_TEXT_WIDTH # turn off for now self.widget_component_dropdown.layout.display = 'none' self.widget_component_dropdown.description = '' # self.widget_accordion_repr_parameters = Accordion() self.widget_accordion_repr_parameters = Tab() self.widget_repr_parameters = self._make_widget_repr_parameters( self.widget_component_slider, self.widget_repr_slider, self.widget_repr_name) self.widget_accordion_repr_parameters.children = [ self.widget_repr_parameters, Box() ] self.widget_accordion_repr_parameters.set_title(0, 'Parameters') self.widget_accordion_repr_parameters.set_title(1, 'Hide') self.widget_accordion_repr_parameters.selected_index = 1 checkbox_reprlist = Checkbox(value=False, description='reprlist') checkbox_reprlist._ngl_name = 'checkbox_reprlist' self.widget_repr_choices = self._make_repr_name_choices( self.widget_component_slider, self.widget_repr_slider) self.widget_repr_choices._ngl_name = 'reprlist_choices' self.widget_repr_add = self._make_add_widget_repr( self.widget_component_slider) def on_update_checkbox_reprlist(change): self.widget_repr_choices.visible = change['new'] checkbox_reprlist.observe(on_update_checkbox_reprlist, names='value') def 
on_repr_name_text_value_changed(change): name = change['new'].strip() old = change['old'].strip() should_update = (self._real_time_update and old and name and name in REPRESENTATION_NAMES and name != change['old'].strip()) if should_update: component = self.widget_component_slider.value repr_index = self.widget_repr_slider.value self._view._remote_call( 'setRepresentation', target='Widget', args=[change['new'], {}, component, repr_index]) self._view._request_repr_parameters(component, repr_index) def on_component_or_repr_slider_value_changed(change): self._view._request_repr_parameters( component=self.widget_component_slider.value, repr_index=self.widget_repr_slider.value) self.widget_component_dropdown.options = tuple( self._view._ngl_component_names) if self.widget_accordion_repr_parameters.selected_index >= 0: self.widget_repr_parameters.name = self.widget_repr_name.value self.widget_repr_parameters.repr_index = self.widget_repr_slider.value self.widget_repr_parameters.component_index = self.widget_component_slider.value def on_repr_selection_value_changed(change): if self._real_time_update: component = self.widget_component_slider.value repr_index = self.widget_repr_slider.value self._view._set_selection(change['new'], component=component, repr_index=repr_index) def on_change_component_dropdown(change): choice = change['new'] if choice: self.widget_component_slider.value = self._view._ngl_component_names.index( choice) self.widget_component_dropdown.observe(on_change_component_dropdown, names='value') self.widget_repr_slider.observe( on_component_or_repr_slider_value_changed, names='value') self.widget_component_slider.observe( on_component_or_repr_slider_value_changed, names='value') self.widget_repr_name.observe(on_repr_name_text_value_changed, names='value') repr_selection.observe(on_repr_selection_value_changed, names='value') self.widget_repr_control_buttons = self._make_button_repr_control( self.widget_component_slider, self.widget_repr_slider, 
repr_selection) blank_box = Box([Label("")]) all_kids = [ self.widget_repr_control_buttons, blank_box, self.widget_repr_add, self.widget_component_dropdown, self.widget_repr_name, repr_selection, self.widget_component_slider, self.widget_repr_slider, self.widget_repr_choices, self.widget_accordion_repr_parameters ] vbox = VBox(all_kids) self._view._request_repr_parameters( component=self.widget_component_slider.value, repr_index=self.widget_repr_slider.value) self.widget_repr = _relayout_master(vbox, width='100%') self._refresh(self.widget_component_slider, self.widget_repr_slider) setattr(self.widget_repr, "_saved_widgets", []) for _box in self.widget_repr.children: if hasattr(_box, 'children'): for kid in _box.children: self.widget_repr._saved_widgets.append(kid) return self.widget_repr def _make_widget_repr_parameters(self, component_slider, repr_slider, repr_name_text=None): name = repr_name_text.value if repr_name_text is not None else ' ' widget = self._view._display_repr(component=component_slider.value, repr_index=repr_slider.value, name=name) widget._ngl_name = 'repr_parameters_box' return widget def _make_button_export_image(self): slider_factor = IntSlider(value=4, min=1, max=10, description='scale') checkbox_antialias = Checkbox(value=True, description='antialias') checkbox_trim = Checkbox(value=False, description='trim') checkbox_transparent = Checkbox(value=False, description='transparent') filename_text = Text(value='Screenshot', description='Filename') delay_text = FloatText(value=1, description='delay (s)', tooltip='hello') start_text, stop_text, step_text = (IntText(value=0, description='start'), IntText(value=self._view.count, description='stop'), IntText(value=1, description='step')) start_text.layout.max_width = stop_text.layout.max_width = step_text.layout.max_width \ = filename_text.layout.max_width = delay_text.layout.max_width = default.DEFAULT_TEXT_WIDTH button_movie_images = Button(description='Export Images') def 
download_image(filename): self._view.download_image(factor=slider_factor.value, antialias=checkbox_antialias.value, trim=checkbox_trim.value, transparent=checkbox_transparent.value, filename=filename) @button_movie_images.on_click def on_click_images(button_movie_images): for i in range(start_text.value, stop_text.value, step_text.value): self._view.frame = i time.sleep(delay_text.value) download_image(filename=filename_text.value + str(i)) time.sleep(delay_text.value) vbox = VBox([ button_movie_images, start_text, stop_text, step_text, delay_text, filename_text, slider_factor, checkbox_antialias, checkbox_trim, checkbox_transparent, ]) form_items = _relayout(vbox, make_form_item_layout()) form = Box(form_items, layout=_make_box_layout()) # form = _relayout_master(vbox) return form def _make_resize_notebook_slider(self): resize_notebook_slider = IntSlider(min=300, max=2000, description='resize notebook') def on_resize_notebook(change): width = change['new'] self._view._remote_call('resizeNotebook', target='Widget', args=[ width, ]) resize_notebook_slider.observe(on_resize_notebook, names='value') return resize_notebook_slider def _make_add_widget_repr(self, component_slider): dropdown_repr_name = Dropdown(options=REPRESENTATION_NAMES, value='cartoon') repr_selection = Text(value='*', description='') repr_button = Button(description='Add', tooltip="""Add representation. 
You can also hit Enter in selection box""") repr_button.layout = Layout(width='auto', flex='1 1 auto') dropdown_repr_name.layout.width = repr_selection.layout.width = default.DEFAULT_TEXT_WIDTH def on_click_or_submit(button_or_text_area): self._view.add_representation( selection=repr_selection.value.strip(), repr_type=dropdown_repr_name.value, component=component_slider.value) repr_button.on_click(on_click_or_submit) repr_selection.on_submit(on_click_or_submit) add_repr_box = HBox([repr_button, dropdown_repr_name, repr_selection]) add_repr_box._ngl_name = 'add_repr_box' return add_repr_box def _make_repr_playground(self): vbox = VBox() children = [] rep_names = REPRESENTATION_NAMES[:] excluded_names = ['ball+stick', 'distance'] for name in excluded_names: rep_names.remove(name) repr_selection = Text(value='*') repr_selection.layout.width = default.DEFAULT_TEXT_WIDTH repr_selection_box = HBox([Label('selection'), repr_selection]) setattr(repr_selection_box, 'value', repr_selection.value) for index, name in enumerate(rep_names): button = ToggleButton(description=name) def make_func(): def on_toggle_button_value_change(change, button=button): selection = repr_selection.value new = change['new'] # True/False if new: self._view.add_representation(button.description, selection=selection) else: self._view._remove_representations_by_name( button.description) return on_toggle_button_value_change button.observe(make_func(), names='value') children.append(button) button_clear = Button(description='clear', button_style='info', icon='fa-eraser') @button_clear.on_click def on_clear(button_clear): self._view.clear() for kid in children: # unselect kid.value = False vbox.children = children + [repr_selection, button_clear] _make_autofit(vbox) self.widget_quick_repr = vbox return self.widget_quick_repr def _make_repr_name_choices(self, component_slider, repr_slider): repr_choices = Dropdown(options=[ " ", ]) def on_chosen(change): repr_name = change.get('new', " ") try: repr_index 
= repr_choices.options.index(repr_name) repr_slider.value = repr_index except ValueError: pass repr_choices.observe(on_chosen, names='value') repr_choices.layout.width = default.DEFAULT_TEXT_WIDTH self.widget_repr_choices = repr_choices return self.widget_repr_choices def _make_drag_widget(self): button_drag = Button(description='widget drag: off', tooltip='dangerous') drag_nb = Button(description='notebook drag: off', tooltip='dangerous') button_reset_notebook = Button(description='notebook: reset', tooltip='reset?') button_dialog = Button(description='dialog', tooltip='make a dialog') button_split_half = Button(description='split screen', tooltip='try best to make a good layout') @button_drag.on_click def on_drag(button_drag): if button_drag.description == 'widget drag: off': self._view._set_draggable(True) button_drag.description = 'widget drag: on' else: self._view._set_draggable(False) button_drag.description = 'widget drag: off' @drag_nb.on_click def on_drag_nb(button_drag): if drag_nb.description == 'notebook drag: off': js_utils._set_notebook_draggable(True) drag_nb.description = 'notebook drag: on' else: js_utils._set_notebook_draggable(False) drag_nb.description = 'notebook drag: off' @button_reset_notebook.on_click def on_reset(button_reset_notebook): js_utils._reset_notebook() @button_dialog.on_click def on_dialog(button_dialog): self._view._remote_call('setDialog', target='Widget') @button_split_half.on_click def on_split_half(button_dialog): from nglview import js_utils import time js_utils._move_notebook_to_the_left() js_utils._set_notebook_width('5%') time.sleep(0.1) self._view._remote_call('setDialog', target='Widget') drag_box = HBox([ button_drag, drag_nb, button_reset_notebook, button_dialog, button_split_half ]) drag_box = _make_autofit(drag_box) self.widget_drag = drag_box return drag_box def _make_spin_box(self): checkbox_spin = Checkbox(self.spin, description='spin') spin_x_slide = IntSlider(self._spin_x, min=-1, max=1, description='spin_x') 
spin_y_slide = IntSlider(self._spin_y, min=-1, max=1, description='spin_y') spin_z_slide = IntSlider(self._spin_z, min=-1, max=1, description='spin_z') spin_speed_slide = FloatSlider(self._spin_speed, min=0, max=0.2, step=0.001, description='spin speed') # spin link((checkbox_spin, 'value'), (self, 'spin')) link((spin_x_slide, 'value'), (self, '_spin_x')) link((spin_y_slide, 'value'), (self, '_spin_y')) link((spin_z_slide, 'value'), (self, '_spin_z')) link((spin_speed_slide, 'value'), (self, '_spin_speed')) spin_box = VBox([ checkbox_spin, spin_x_slide, spin_y_slide, spin_z_slide, spin_speed_slide ]) spin_box = _relayout_master(spin_box, width='75%') return spin_box def _make_widget_picked(self): self.widget_picked = self._make_text_picked() picked_box = HBox([ self.widget_picked, ]) return _relayout_master(picked_box, width='75%') def _make_export_image_widget(self): if self.widget_export_image is None: self.widget_export_image = HBox([self._make_button_export_image()]) return self.widget_export_image def _make_extra_box(self): if self.widget_extra is None: extra_list = [(self._make_drag_widget, 'Drag'), (self._make_spin_box, 'Spin'), (self._make_widget_picked, 'Picked'), (self._make_repr_playground, 'Quick'), (self._make_export_image_widget, 'Image'), (self._make_command_box, 'Command')] extra_box = _make_delay_tab(extra_list, selected_index=0) self.widget_extra = extra_box return self.widget_extra def _make_theme_box(self): if self.widget_theme is None: self.widget_theme = Box([ self._make_button_theme(), self._make_button_reset_theme(hide_toolbar=False), self._make_button_reset_theme(hide_toolbar=True), self._make_button_clean_error_output() ]) return self.widget_theme def _make_general_box(self): if self.widget_general is None: step_slide = IntSlider(value=self.step, min=-100, max=100, description='step') delay_text = IntSlider(value=self.delay, min=10, max=1000, description='delay') toggle_button_interpolate = ToggleButton( self.interpolate, 
description='Smoothing', tooltip='smoothing trajectory') link((toggle_button_interpolate, 'value'), (self, 'interpolate')) background_color_picker = ColorPicker(value='white', description='background') camera_type = Dropdown(value=self.camera, options=['perspective', 'orthographic'], description='camera') link((step_slide, 'value'), (self, 'step')) link((delay_text, 'value'), (self, 'delay')) link((toggle_button_interpolate, 'value'), (self, 'interpolate')) link((camera_type, 'value'), (self, 'camera')) link((background_color_picker, 'value'), (self._view, 'background')) center_button = self._make_button_center() render_button = self._show_download_image() qtconsole_button = self._make_button_qtconsole() center_render_hbox = _make_autofit( HBox([ toggle_button_interpolate, center_button, render_button, qtconsole_button ])) v0_left = VBox([ step_slide, delay_text, background_color_picker, camera_type, center_render_hbox, ]) v0_left = _relayout_master(v0_left, width='100%') self.widget_general = v0_left return self.widget_general def _make_command_box(self): widget_text_command = Text() @widget_text_command.on_submit def _on_submit_command(_): command = widget_text_command.value js_utils.execute(command) widget_text_command.value = '' return widget_text_command def _create_all_tabs(self): tab = self._display() for index, _ in enumerate(tab.children): # trigger ceating widgets tab.selected_index = index self.widget_extra = self._make_extra_box() for index, _ in enumerate(self.widget_extra.children): self.widget_extra.selected_index = index def _simplify_repr_control(self): for widget in self.widget_repr._saved_widgets: if not isinstance(widget, Tab): widget.layout.display = 'none' self.widget_repr_choices.layout.display = 'flex' self.widget_accordion_repr_parameters.selected_index = 0
class LDAPAuthenticator(Authenticator): server_address = Unicode(config=True, help=""" Address of the LDAP server to contact. Could be an IP address or hostname. """) server_port = Int(config=True, help=""" Port on which to contact the LDAP server. Defaults to `636` if `use_ssl` is set, `389` otherwise. """) def _server_port_default(self): if self.use_ssl: return 636 # default SSL port for LDAP else: return 389 # default plaintext port for LDAP use_ssl = Bool(True, config=True, help=""" Use SSL to communicate with the LDAP server. Highly recommended! Your LDAP server must be configured to support this, however. """) bind_dn_template = Union([List(), Unicode()], config=True, help=""" Template from which to construct the full dn when authenticating to LDAP. {username} is replaced with the actual username used to log in. If your LDAP is set in such a way that the userdn can not be formed from a template, but must be looked up with an attribute (such as uid or sAMAccountName), please see `lookup_dn`. It might be particularly relevant for ActiveDirectory installs. Unicode Example: uid={username},ou=people,dc=wikimedia,dc=org List Example: [ uid={username},ou=people,dc=wikimedia,dc=org, uid={username},ou=Developers,dc=wikimedia,dc=org ] """) allowed_groups = List(config=True, allow_none=True, default=None, help=""" List of LDAP group DNs that users could be members of to be granted access. If a user is in any one of the listed groups, then that user is granted access. Membership is tested by fetching info about each group and looking for the User's dn to be a value of one of `member` or `uniqueMember`, *or* if the username being used to log in with is value of the `uid`. Set to an empty list or None to allow all users that have an LDAP account to log in, without performing any group membership checks. """) # FIXME: Use something other than this? THIS IS LAME, akin to websites restricting things you # can use in usernames / passwords to protect from SQL injection! 
valid_username_regex = Unicode(r'^[a-z][.a-z0-9_-]*$', config=True, help=""" Regex for validating usernames - those that do not match this regex will be rejected. This is primarily used as a measure against LDAP injection, which has fatal security considerations. The default works for most LDAP installations, but some users might need to modify it to fit their custom installs. If you are modifying it, be sure to understand the implications of allowing additional characters in usernames and what that means for LDAP injection issues. See https://www.owasp.org/index.php/LDAP_injection for an overview of LDAP injection. """) lookup_dn = Bool(False, config=True, help=""" Form user's DN by looking up an entry from directory By default, LDAPAuthenticator finds the user's DN by using `bind_dn_template`. However, in some installations, the user's DN does not contain the username, and hence needs to be looked up. You can set this to True and then use `user_search_base` and `user_attribute` to accomplish this. """) user_search_base = Unicode(config=True, default=None, allow_none=True, help=""" Base for looking up user accounts in the directory, if `lookup_dn` is set to True. LDAPAuthenticator will search all objects matching under this base where the `user_attribute` is set to the current username to form the userdn. 
For example, if all users objects existed under the base ou=people,dc=wikimedia,dc=org, and the username users use is set with the attribute `uid`, you can use the following config: ``` c.LDAPAuthenticator.lookup_dn = True c.LDAPAuthenticator.lookup_dn_search_filter = '({login_attr}={login})' c.LDAPAuthenticator.lookup_dn_search_user = '******' c.LDAPAuthenticator.lookup_dn_search_password = '******' c.LDAPAuthenticator.user_search_base = 'ou=people,dc=wikimedia,dc=org' c.LDAPAuthenticator.user_attribute = 'sAMAccountName' c.LDAPAuthenticator.lookup_dn_user_dn_attribute = 'cn' ``` """) user_attribute = Unicode(config=True, default=None, allow_none=True, help=""" Attribute containing user's name, if `lookup_dn` is set to True. See `user_search_base` for info on how this attribute is used. For most LDAP servers, this is uid. For Active Directory, it is sAMAccountName. """) lookup_dn_search_filter = Unicode(config=True, default_value='({login_attr}={login})', allow_none=True, help=""" How to query LDAP for user name lookup, if `lookup_dn` is set to True. """) lookup_dn_search_user = Unicode(config=True, default_value=None, allow_none=True, help=""" Technical account for user lookup, if `lookup_dn` is set to True. If both lookup_dn_search_user and lookup_dn_search_password are None, then anonymous LDAP query will be done. """) lookup_dn_search_password = Unicode(config=True, default_value=None, allow_none=True, help=""" Technical account for user lookup, if `lookup_dn` is set to True. """) lookup_dn_user_dn_attribute = Unicode(config=True, default_value=None, allow_none=True, help=""" Attribute containing user's name needed for building DN string, if `lookup_dn` is set to True. See `user_search_base` for info on how this attribute is used. For most LDAP servers, this is username. For Active Directory, it is cn. """) escape_userdn = Bool(False, config=True, help=""" If set to True, escape special chars in userdn when authenticating in LDAP. 
On some LDAP servers, when userdn contains chars like '(', ')', '\' authentication may fail when those chars are not escaped. """) def resolve_username(self, username_supplied_by_user): if self.lookup_dn: server = ldap3.Server(self.server_address, port=self.server_port, use_ssl=self.use_ssl) search_filter = self.lookup_dn_search_filter.format( login_attr=self.user_attribute, login=username_supplied_by_user) self.log.debug( "Looking up user with search_base={search_base}, search_filter='{search_filter}', attributes={attributes}" .format(search_base=self.user_search_base, search_filter=search_filter, attributes=self.user_attribute)) conn = ldap3.Connection(server, user=self.escape_userdn_if_needed( self.lookup_dn_search_user), password=self.lookup_dn_search_password) is_bound = conn.bind() if not is_bound: self.log.warn( "Can't connect to LDAP. Server={server}. User={user}. Password={password}" .format(server=self.server_address, user=self.escape_userdn_if_needed( self.lookup_dn_search_user), password=self.lookup_dn_search_password)) return None conn.search(search_base=self.user_search_base, search_scope=ldap3.SUBTREE, search_filter=search_filter, attributes=[self.lookup_dn_user_dn_attribute]) if len(conn.response ) == 0 or 'attributes' not in conn.response[0].keys(): self.log.warn( 'username:%s No such user entry found when looking up with attribute %s', username_supplied_by_user, self.user_attribute) return None return conn.response[0]['attributes'][ self.lookup_dn_user_dn_attribute] else: return username_supplied_by_user def escape_userdn_if_needed(self, userdn): if self.escape_userdn: return escape_filter_chars(userdn) else: return userdn @gen.coroutine def authenticate(self, handler, data): username = data['username'] password = data['password'] # Get LDAP Connection def getConnection(userdn, username, password): server = ldap3.Server(self.server_address, port=self.server_port, use_ssl=self.use_ssl) self.log.debug( 'Attempting to bind {username} with 
{userdn}'.format( username=username, userdn=userdn)) conn = ldap3.Connection(server, user=self.escape_userdn_if_needed(userdn), password=password) return conn # Protect against invalid usernames as well as LDAP injection attacks if not re.match(self.valid_username_regex, username): self.log.warn( 'username:%s Illegal characters in username, must match regex %s', username, self.valid_username_regex) return None # No empty passwords! if password is None or password.strip() == '': self.log.warn('username:%s Login denied for blank password', username) return None isBound = False self.log.debug("TYPE= '%s'", isinstance(self.bind_dn_template, list)) resolved_username = self.resolve_username(username) if resolved_username is None: return None # In case, there are multiple binding templates if isinstance(self.bind_dn_template, list): for dn in self.bind_dn_template: userdn = dn.format(username=resolved_username) conn = getConnection(userdn, username, password) isBound = conn.bind() self.log.debug( 'Status of user bind {username} with {userdn} : {isBound}'. 
format(username=username, userdn=userdn, isBound=isBound)) if isBound: break else: userdn = self.bind_dn_template.format(username=resolved_username) conn = getConnection(userdn, username, password) isBound = conn.bind() if isBound: if self.allowed_groups: self.log.debug('username:%s Using dn %s', username, userdn) for group in self.allowed_groups: groupfilter = ('(|' '(member={userdn})' '(uniqueMember={userdn})' '(memberUid={uid})' ')').format( userdn=escape_filter_chars(userdn), uid=escape_filter_chars(username)) groupattributes = ['member', 'uniqueMember', 'memberUid'] if conn.search(group, search_scope=ldap3.BASE, search_filter=groupfilter, attributes=groupattributes): return username # If we reach here, then none of the groups matched self.log.warn( 'username:%s User not in any of the allowed groups', username) return None else: return username else: self.log.warn('Invalid password for user {username}'.format( username=userdn, )) return None
class Circle(CircleMarker): _view_name = Unicode('LeafletCircleView').tag(sync=True) _model_name = Unicode('LeafletCircleModel').tag(sync=True) # Options radius = Int(1000, help='radius of circle in meters').tag(sync=True, o=True)
class EventAnimationCreator(Tool): name = "EventAnimationCreator" description = "Create an animation of the camera image through timeslices" req_event = Int(0, help='Event to plot').tag(config=True) aliases = Dict( dict(r='EventFileReaderFactory.reader', f='EventFileReaderFactory.input_path', max_events='EventFileReaderFactory.max_events', ped='CameraR1CalibratorFactory.pedestal_path', tf='CameraR1CalibratorFactory.tf_path', pe='CameraR1CalibratorFactory.pe_path', cleaner='WaveformCleanerFactory.cleaner', e='EventAnimationCreator.req_event', start='Animator.start', end='Animator.end', p1='Animator.p1', p2='Animator.p2')) classes = List([ EventFileReaderFactory, CameraR1CalibratorFactory, WaveformCleanerFactory, CHECMSPEFitter, Animator ]) def __init__(self, **kwargs): super().__init__(**kwargs) self.reader = None self.r1 = None self.dl0 = None self.cleaner = None self.extractor = None self.dl1 = None self.fitter = None self.dead = None self.adc2pe = None self.animator = None def setup(self): self.log_format = "%(levelname)s: %(message)s [%(name)s.%(funcName)s]" kwargs = dict(config=self.config, tool=self) reader_factory = EventFileReaderFactory(**kwargs) reader_class = reader_factory.get_class() self.reader = reader_class(**kwargs) r1_factory = CameraR1CalibratorFactory(origin=self.reader.origin, **kwargs) r1_class = r1_factory.get_class() self.r1 = r1_class(**kwargs) cleaner_factory = WaveformCleanerFactory(**kwargs) cleaner_class = cleaner_factory.get_class() self.cleaner = cleaner_class(**kwargs) extractor_factory = ChargeExtractorFactory(**kwargs) extractor_class = extractor_factory.get_class() self.extractor = extractor_class(**kwargs) self.dl0 = CameraDL0Reducer(**kwargs) self.dl1 = CameraDL1Calibrator(extractor=self.extractor, cleaner=self.cleaner, **kwargs) self.fitter = CHECMSPEFitter(**kwargs) self.dead = Dead() self.animator = Animator(**kwargs) def start(self): event = self.reader.get_event(self.req_event) telid = list(event.r0.tels_with_data)[0] geom = 
CameraGeometry.guess(*event.inst.pixel_pos[0], event.inst.optical_foclen[0]) self.r1.calibrate(event) self.dl0.reduce(event) self.dl1.calibrate(event) cleaned = event.dl1.tel[telid].cleaned[0] output_dir = self.reader.output_directory self.animator.plot(cleaned, geom, self.req_event, output_dir) def finish(self): pass
class Animator(Component): name = 'Animator' start = Int(40, help='Time to start gif').tag(config=True) end = Int(8, help='Time to end gif').tag(config=True) p1 = Int(0, help='Pixel 1').tag(config=True) p2 = Int(0, help='Pixel 2').tag(config=True) def __init__(self, config, tool, **kwargs): super().__init__(config=config, parent=tool, **kwargs) self.fig = plt.figure(figsize=(24, 10)) self.ax1 = self.fig.add_subplot(2, 2, 1) self.ax2 = self.fig.add_subplot(2, 2, 3) self.camera = self.fig.add_subplot(1, 2, 2) def plot(self, waveforms, geom, event_id, output_dir): camera = CameraDisplay(geom, ax=self.camera, image=np.zeros(2048), cmap='viridis') camera.add_colorbar() max_ = np.percentile(waveforms[:, self.start:self.end].max(), 60) camera.set_limits_minmax(0, max_) self.ax1.plot(waveforms[self.p1, :]) self.ax2.plot(waveforms[self.p2, :]) self.fig.suptitle("Event {}".format(event_id)) self.ax1.set_title("Pixel: {}".format(self.p1)) self.ax1.set_xlabel("Time (ns)") self.ax1.set_ylabel("Amplitude (p.e.)") self.ax2.set_title("Pixel: {}".format(self.p2)) self.ax2.set_xlabel("Time (ns)") self.ax2.set_ylabel("Amplitude (p.e.)") camera.colorbar.set_label("Amplitude (p.e.)") line1, = self.ax1.plot([0, 0], self.ax1.get_ylim(), color='r', alpha=1) line2, = self.ax2.plot([0, 0], self.ax2.get_ylim(), color='r', alpha=1) self.camera.annotate("Pixel: {}".format(self.p1), xy=(geom.pix_x.value[self.p1], geom.pix_y.value[self.p1]), xycoords='data', xytext=(0.05, 0.98), textcoords='axes fraction', arrowprops=dict(facecolor='red', width=2, alpha=0.4), horizontalalignment='left', verticalalignment='top') self.camera.annotate("Pixel: {}".format(self.p2), xy=(geom.pix_x.value[self.p2], geom.pix_y.value[self.p2]), xycoords='data', xytext=(0.05, 0.94), textcoords='axes fraction', arrowprops=dict(facecolor='orange', width=2, alpha=0.4), horizontalalignment='left', verticalalignment='top') # Create animation div = 5 increment = 1 / div n_frames = int((self.end - self.start) / increment) 
interval = int(500 * increment) # Prepare Output output_path = join(output_dir, "animation_e{}.gif".format(event_id)) if not exists(output_dir): self.log.info("Creating directory: {}".format(output_dir)) makedirs(output_dir) self.log.info("Output: {}".format(output_path)) with tqdm(total=n_frames, desc="Creating animation") as pbar: def animate(i): pbar.update(1) t = self.start + (i / div) camera.image = waveforms[:, int(t)] line1.set_xdata(t) line2.set_xdata(t) anim = animation.FuncAnimation(self.fig, animate, frames=n_frames, interval=interval) anim.save(output_path, writer='imagemagick') self.log.info("Created animation: {}".format(output_path))
class Map(DOMWidget, InteractMixin): @default('layout') def _default_layout(self): return Layout(height='400px', align_self='stretch') _view_name = Unicode('LeafletMapView').tag(sync=True) _model_name = Unicode('LeafletMapModel').tag(sync=True) _view_module = Unicode('jupyter-leaflet').tag(sync=True) _model_module = Unicode('jupyter-leaflet').tag(sync=True) # Map options center = List(def_loc).tag(sync=True, o=True) zoom_start = Int(12).tag(sync=True, o=True) zoom = Int(12).tag(sync=True, o=True) max_zoom = Int(18).tag(sync=True, o=True) min_zoom = Int(1).tag(sync=True, o=True) # Interaction options dragging = Bool(True).tag(sync=True, o=True) touch_zoom = Bool(True).tag(sync=True, o=True) scroll_wheel_zoom = Bool(False).tag(sync=True, o=True) double_click_zoom = Bool(True).tag(sync=True, o=True) box_zoom = Bool(True).tag(sync=True, o=True) tap = Bool(True).tag(sync=True, o=True) tap_tolerance = Int(15).tag(sync=True, o=True) world_copy_jump = Bool(False).tag(sync=True, o=True) close_popup_on_click = Bool(True).tag(sync=True, o=True) bounce_at_zoom_limits = Bool(True).tag(sync=True, o=True) keyboard = Bool(True).tag(sync=True, o=True) keyboard_pan_offset = Int(80).tag(sync=True, o=True) keyboard_zoom_offset = Int(1).tag(sync=True, o=True) inertia = Bool(True).tag(sync=True, o=True) inertia_deceleration = Int(3000).tag(sync=True, o=True) inertia_max_speed = Int(1500).tag(sync=True, o=True) # inertia_threshold = Int(?, o=True).tag(sync=True) zoom_control = Bool(True).tag(sync=True, o=True) attribution_control = Bool(True).tag(sync=True, o=True) # fade_animation = Bool(?).tag(sync=True, o=True) # zoom_animation = Bool(?).tag(sync=True, o=True) zoom_animation_threshold = Int(4).tag(sync=True, o=True) # marker_zoom_animation = Bool(?).tag(sync=True, o=True) options = List(trait=Unicode).tag(sync=True) @default('options') def _default_options(self): return [name for name in self.traits(o=True)] _south = Float(def_loc[0]).tag(sync=True) _north = 
Float(def_loc[0]).tag(sync=True) _east = Float(def_loc[1]).tag(sync=True) _west = Float(def_loc[1]).tag(sync=True) default_tiles = Instance(TileLayer, allow_none=True) @default('default_tiles') def _default_tiles(self): return TileLayer() @property def north(self): return self._north @property def south(self): return self._south @property def east(self): return self._east @property def west(self): return self._west @property def bounds_polygon(self): return [(self.north, self.west), (self.north, self.east), (self.south, self.east), (self.south, self.west)] @property def bounds(self): return [(self.south, self.west), (self.north, self.east)] def __init__(self, **kwargs): super(Map, self).__init__(**kwargs) self.on_displayed(self._fire_children_displayed) if self.default_tiles is not None: self.layers = (self.default_tiles, ) self.on_msg(self._handle_leaflet_event) def _fire_children_displayed(self, widget, **kwargs): for layer in self.layers: layer._handle_displayed(**kwargs) for control in self.controls: control._handle_displayed(**kwargs) layers = Tuple(trait=Instance(Layer)).tag(sync=True, **widget_serialization) layer_ids = List() @validate('layers') def _validate_layers(self, proposal): """Validate layers list. Makes sure only one instance of any given layer can exist in the layers list. 
""" self.layer_ids = [l.model_id for l in proposal['value']] if len(set(self.layer_ids)) != len(self.layer_ids): raise LayerException( 'duplicate layer detected, only use each layer once') return proposal['value'] def add_layer(self, layer): if layer.model_id in self.layer_ids: raise LayerException('layer already on map: %r' % layer) layer._map = self self.layers = tuple([l for l in self.layers] + [layer]) layer.visible = True def remove_layer(self, layer): if layer.model_id not in self.layer_ids: raise LayerException('layer not on map: %r' % layer) self.layers = tuple( [l for l in self.layers if l.model_id != layer.model_id]) layer.visible = False def clear_layers(self): self.layers = () controls = Tuple(trait=Instance(Control)).tag(sync=True, **widget_serialization) control_ids = List() @validate('controls') def _validate_controls(self, proposal): """Validate controls list. Makes sure only one instance of any given layer can exist in the controls list. """ self.control_ids = [c.model_id for c in proposal['value']] if len(set(self.control_ids)) != len(self.control_ids): raise ControlException( 'duplicate control detected, only use each control once') return proposal['value'] def add_control(self, control): if control.model_id in self.control_ids: raise ControlException('control already on map: %r' % control) control._map = self self.controls = tuple([c for c in self.controls] + [control]) control.visible = True def remove_control(self, control): if control.model_id not in self.control_ids: raise ControlException('control not on map: %r' % control) self.controls = tuple( [c for c in self.controls if c.model_id != control.model_id]) control.visible = False def clear_controls(self): self.controls = () def __iadd__(self, item): if isinstance(item, Layer): self.add_layer(item) elif isinstance(item, Control): self.add_control(item) return self def __isub__(self, item): if isinstance(item, Layer): self.remove_layer(item) elif isinstance(item, Control): 
self.remove_control(item) return self def __add__(self, item): if isinstance(item, Layer): self.add_layer(item) elif isinstance(item, Control): self.add_control(item) return self def _handle_leaflet_event(self, _, content): pass
class CircleMarker(Circle): _view_name = Unicode('LeafletCircleMarkerView').tag(sync=True) _model_name = Unicode('LeafletCircleMarkerModel').tag(sync=True) radius = Int(10, help="radius of circle in pixels").tag(sync=True)
class Circle(Path): _view_name = Unicode('LeafletCircleView').tag(sync=True) _model_name = Unicode('LeafletCircleModel').tag(sync=True) location = List(def_loc).tag(sync=True) radius = Int(1000, help="radius of circle in meters").tag(sync=True)
class DataGenerator(Configurable): #params for data generator max_q_len = Int(10, help='max q len').tag(config=True) max_d_len = Int(500, help='max document len').tag(config=True) q_name = Unicode('q') d_name = Unicode('d') q_str_name = Unicode('q_str') q_lens_name = Unicode('q_lens') aux_d_name = Unicode('d_aux') vocabulary_size = Int(2000000).tag(config=True) def __init__(self, **kwargs): #init the data generator super(DataGenerator, self).__init__(**kwargs) print ("generator's vocabulary size: ", self.vocabulary_size) def pairwise_reader(self, pair_stream, batch_size, with_idf=False): #generate the batch of x,y in training time l_q = [] l_q_str = [] l_d = [] l_d_aux = [] l_y = [] l_q_lens = [] for line in pair_stream: cols = line.strip().split('\t') y = float(1.0) l_q_str.append(cols[0]) q = np.array([int(t) for t in cols[0].split(',') if int(t) < self.vocabulary_size]) t1 = np.array([int(t) for t in cols[1].split(',') if int(t) < self.vocabulary_size]) t2 = np.array([int(t) for t in cols[2].split(',') if int(t) < self.vocabulary_size]) #padding v_q = np.zeros(self.max_q_len) v_d = np.zeros(self.max_d_len) v_d_aux = np.zeros(self.max_d_len) v_q[:min(q.shape[0], self.max_q_len)] = q[:min(q.shape[0], self.max_q_len)] v_d[:min(t1.shape[0], self.max_d_len)] = t1[:min(t1.shape[0], self.max_d_len)] v_d_aux[:min(t2.shape[0], self.max_d_len)] = t2[:min(t2.shape[0], self.max_d_len)] l_q.append(v_q) l_d.append(v_d) l_d_aux.append(v_d_aux) l_y.append(y) l_q_lens.append(len(q)) if len(l_q) >= batch_size: Q = np.array(l_q, dtype=int,) D = np.array(l_d, dtype=int,) D_aux = np.array(l_d_aux, dtype=int,) Q_lens = np.array(l_q_lens, dtype=int,) Y = np.array(l_y, dtype=int,) X = {self.q_name: Q, self.d_name: D, self.aux_d_name: D_aux, self.q_lens_name: Q_lens, self.q_str_name: l_q_str} yield X, Y l_q, l_d, l_d_aux, l_y, l_q_lens, l_ids, l_q_str = [], [], [], [], [], [], [] if l_q: Q = np.array(l_q, dtype=int,) D = np.array(l_d, dtype=int,) D_aux = np.array(l_d_aux, dtype=int,) 
Q_lens = np.array(l_q_lens, dtype=int,) Y = np.array(l_y, dtype=int,) X = {self.q_name: Q, self.d_name: D, self.aux_d_name: D_aux, self.q_lens_name: Q_lens, self.q_str_name: l_q_str} yield X, Y def test_pairwise_reader(self, pair_stream, batch_size): #generate the batch of x,y in test time l_q = [] l_q_lens = [] l_d = [] for line in pair_stream: cols = line.strip().split('\t') q = np.array([int(t) for t in cols[0].split(',') if int(t) < self.vocabulary_size]) t = np.array([int(t) for t in cols[1].split(',') if int(t) < self.vocabulary_size]) v_q = np.zeros(self.max_q_len) v_d = np.zeros(self.max_d_len) v_q[:min(q.shape[0], self.max_q_len)] = q[:min(q.shape[0], self.max_q_len)] v_d[:min(t.shape[0], self.max_d_len)] = t[:min(t.shape[0], self.max_d_len)] l_q.append(v_q) l_d.append(v_d) l_q_lens.append(len(q)) if len(l_q) >= batch_size: Q = np.array(l_q, dtype=int,) D = np.array(l_d, dtype=int,) Q_lens = np.array(l_q_lens, dtype=int,) X = {self.q_name: Q, self.d_name: D, self.q_lens_name: Q_lens} yield X l_q, l_d, l_q_lens = [], [], [] if l_q: Q = np.array(l_q, dtype=int,) D = np.array(l_d, dtype=int,) Q_lens = np.array(l_q_lens, dtype=int,) X = {self.q_name: Q, self.d_name: D, self.q_lens_name: Q_lens} yield X
class MeasureControl(Control): _view_name = Unicode('LeafletMeasureControlView').tag(sync=True) _model_name = Unicode('LeafletMeasureControlModel').tag(sync=True) _length_units = ['feet', 'meters', 'miles', 'kilometers'] _area_units = ['acres', 'hectares', 'sqfeet', 'sqmeters', 'sqmiles'] _custom_units_dict = {} _custom_units = Dict().tag(sync=True) primary_length_unit = Enum( values=_length_units, default_value='feet', help="""Possible values are feet, meters, miles, kilometers or any user defined unit""" ).tag(sync=True, o=True) secondary_length_unit = Enum( values=_length_units, default_value=None, allow_none=True, help="""Possible values are feet, meters, miles, kilometers or any user defined unit""" ).tag(sync=True, o=True) primary_area_unit = Enum( values=_area_units, default_value='acres', help="""Possible values are acres, hectares, sqfeet, sqmeters, sqmiles or any user defined unit""" ).tag(sync=True, o=True) secondary_area_unit = Enum( values=_area_units, default_value=None, allow_none=True, help="""Possible values are acres, hectares, sqfeet, sqmeters, sqmiles or any user defined unit""" ).tag(sync=True, o=True) active_color = Color('#ABE67E').tag(sync=True, o=True) completed_color = Color('#C8F2BE').tag(sync=True, o=True) popup_options = Dict({ 'className': 'leaflet-measure-resultpopup', 'autoPanPadding': [10, 10] }).tag(sync=True, o=True) capture_z_index = Int(10000).tag(sync=True, o=True) def add_length_unit(self, name, factor, decimals=0): self._length_units.append(name) self._add_unit(name, factor, decimals) def add_area_unit(self, name, factor, decimals=0): self._area_units.append(name) self._add_unit(name, factor, decimals) def _add_unit(self, name, factor, decimals): self._custom_units_dict[name] = { 'factor': factor, 'display': name, 'decimals': decimals } self._custom_units = dict(**self._custom_units_dict)
class Ripple(Renderer):
    """Renderer that draws expanding rings of color under pressed keys."""

    # meta
    meta = RendererMeta('Ripples', 'Ripples of color when keys are pressed',
                        'Steve Kondik', '1.0')

    # configurable traits
    ripple_width = Int(default_value=DEFAULT_WIDTH, min=1, max=5).tag(config=True)
    speed = Int(default_value=DEFAULT_SPEED, min=1, max=9).tag(config=True)
    preset = ColorPresetTrait(ColorScheme, default_value=None).tag(config=True)
    random = Bool(True).tag(config=True)
    color = ColorTrait().tag(config=True)

    def __init__(self, *args, **kwargs):
        super(Ripple, self).__init__(*args, **kwargs)
        self._generator = ColorUtils.rainbow_generator()
        self._max_distance = None
        self.key_expire_time = DEFAULT_SPEED * EXPIRE_TIME_FACTOR
        self.fps = 30

    @observe('speed')
    def _set_speed(self, change):
        # Higher speed -> ripples expire sooner.
        self.key_expire_time = change.new * EXPIRE_TIME_FACTOR

    def _process_events(self, events):
        """Assign each not-yet-colored event the next color from the generator."""
        if self._generator is None:
            return None
        for event in events:
            if COLOR_KEY not in event.data:
                event.data[COLOR_KEY] = next(self._generator)

    @staticmethod
    def _ease(n):
        """Quintic ease-in/out of n clamped to [0, 1]."""
        n = clamp(n, 0.0, 1.0)
        n = 2 * n
        if n < 1:
            return 0.5 * n**5
        n = n - 2
        return 0.5 * (n**5 + 2)

    def _draw_circles(self, layer, radius, event):
        """Draw the (up to ripple_width) concentric rings for one key event."""
        width = self.ripple_width
        if COLOR_KEY not in event.data:
            return
        if event.coords is None or len(event.coords) == 0:
            self.logger.error('No coordinates available: %s', event)
            return
        # Compute (and cache on the event) a color scheme for the rings.
        if SCHEME_KEY in event.data:
            colors = event.data[SCHEME_KEY]
        else:
            color = event.data[COLOR_KEY]
            if width > 1:
                colors = ColorUtils.color_scheme(color=color, base_color=color, steps=width)
            else:
                colors = [color]
            event.data[SCHEME_KEY] = colors
        for circle_num in range(width - 1, -1, -1):
            if radius - circle_num < 0:
                continue
            rad = radius - circle_num
            # Fade the ring's alpha as it approaches the edge of the frame.
            a = Ripple._ease(1.0 - (rad / self._max_distance))
            cc = (*colors[circle_num].rgb, colors[circle_num].alpha * a)
            for coord in event.coords:
                # NOTE(review): 1.33 squashes the ellipse horizontally —
                # presumably the key matrix aspect ratio; confirm.
                layer.ellipse(coord.y, coord.x, rad / 1.33, rad, color=cc)

    async def draw(self, layer, timestamp):
        """
        Draw the next layer
        """
        # Yield until the queue becomes active
        events = await self.get_input_events()
        if len(events) > 0:
            self._process_events(events)
            # paint circles in descending timestamp order (oldest first)
            events = sorted(events, key=operator.itemgetter(0), reverse=True)
            for event in events:
                distance = 1.0 - event.percent_complete
                if distance < 1.0:
                    radius = self._max_distance * distance
                    self._draw_circles(layer, radius, event)
            return True
        return False

    @observe('preset', 'color', 'background_color', 'random')
    def _update_colors(self, change=None):
        """Rebuild the color generator when a color-related trait changes."""
        # BUG FIX: the signature allows change=None (direct invocation), but
        # the body previously dereferenced change.new unconditionally, raising
        # AttributeError. Bail out when no change object was supplied.
        if change is None:
            return
        with self.hold_trait_notifications():
            if change.new is None:
                return
            if change.name == 'preset':
                self.color = 'black'
                self.random = False
                self._generator = ColorUtils.color_generator(list(change.new.value))
            elif change.name == 'random' and change.new:
                self.preset = None
                self.color = 'black'
                self._generator = ColorUtils.rainbow_generator()
            else:
                self.preset = None
                self.random = False
                base_color = self.background_color
                if base_color == (0, 0, 0, 1):
                    base_color = None
                self._generator = ColorUtils.scheme_generator(
                    color=self.color, base_color=base_color)

    def init(self, frame) -> bool:
        """Cache the frame diagonal; fails when no key input is available."""
        if not self.has_key_input:
            return False
        self._max_distance = math.hypot(frame.width, frame.height)
        return True
class Layer(Widget, InteractMixin):
    """Base class for Leaflet layers; forwards frontend mouse events to
    registered Python callbacks."""

    _view_name = Unicode('LeafletLayerView').tag(sync=True)
    _model_name = Unicode('LeafletLayerModel').tag(sync=True)
    _view_module = Unicode('jupyter-leaflet').tag(sync=True)
    _model_module = Unicode('jupyter-leaflet').tag(sync=True)
    _view_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)
    _model_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)

    name = Unicode('').tag(sync=True)
    base = Bool(False).tag(sync=True)
    bottom = Bool(False).tag(sync=True)
    popup = Instance(Widget, allow_none=True, default_value=None).tag(sync=True, **widget_serialization)
    popup_min_width = Int(50).tag(sync=True)
    popup_max_width = Int(300).tag(sync=True)
    popup_max_height = Int(default_value=None, allow_none=True).tag(sync=True)

    options = List(trait=Unicode).tag(sync=True)

    def __init__(self, **kwargs):
        super(Layer, self).__init__(**kwargs)
        self.on_msg(self._handle_mouse_events)

    @default('options')
    def _default_options(self):
        # All traits tagged o=True are forwarded as Leaflet options.
        return list(self.traits(o=True))

    # Event handling
    _click_callbacks = Instance(CallbackDispatcher, ())
    _dblclick_callbacks = Instance(CallbackDispatcher, ())
    _mousedown_callbacks = Instance(CallbackDispatcher, ())
    _mouseup_callbacks = Instance(CallbackDispatcher, ())
    _mouseover_callbacks = Instance(CallbackDispatcher, ())
    _mouseout_callbacks = Instance(CallbackDispatcher, ())

    def _handle_mouse_events(self, _, content, buffers):
        """Route one frontend message to the dispatcher for its event type."""
        event_type = content.get('type', '')
        dispatchers = {
            'click': self._click_callbacks,
            'dblclick': self._dblclick_callbacks,
            'mousedown': self._mousedown_callbacks,
            'mouseup': self._mouseup_callbacks,
            'mouseover': self._mouseover_callbacks,
            'mouseout': self._mouseout_callbacks,
        }
        dispatcher = dispatchers.get(event_type)
        if dispatcher is not None:
            dispatcher(**content)

    def on_click(self, callback, remove=False):
        """Register (or with remove=True, unregister) a click callback."""
        self._click_callbacks.register_callback(callback, remove=remove)

    def on_dblclick(self, callback, remove=False):
        """Register (or with remove=True, unregister) a double-click callback."""
        self._dblclick_callbacks.register_callback(callback, remove=remove)

    def on_mousedown(self, callback, remove=False):
        """Register (or with remove=True, unregister) a mouse-down callback."""
        self._mousedown_callbacks.register_callback(callback, remove=remove)

    def on_mouseup(self, callback, remove=False):
        """Register (or with remove=True, unregister) a mouse-up callback."""
        self._mouseup_callbacks.register_callback(callback, remove=remove)

    def on_mouseover(self, callback, remove=False):
        """Register (or with remove=True, unregister) a mouse-over callback."""
        self._mouseover_callbacks.register_callback(callback, remove=remove)

    def on_mouseout(self, callback, remove=False):
        """Register (or with remove=True, unregister) a mouse-out callback."""
        self._mouseout_callbacks.register_callback(callback, remove=remove)
class BaseNN(Configurable):
    """Shared helpers for Gaussian-kernel-pooling ranking models."""

    # number of Gaussian kernels, the first being the exact-match kernel
    n_bins = Int(
        11,
        help="number of kernels (including exact match)").tag(config=True)
    # dimensionality of the first dense layer
    weight_size = Int(50, help="dimension of the first layer").tag(config=True)

    def __init__(self, **kwargs):
        super(BaseNN, self).__init__(**kwargs)

    @staticmethod
    def kernal_mus(n_kernels, use_exact):
        """
        get the mu for each guassian kernel. Mu is the middle of each bin
        :param n_kernels: number of kernels (including exact match). first one is exact match
        :param use_exact: when False the first mu is 2 (outside [-1, 1]),
            effectively disabling the exact-match kernel
        :return: l_mu, a list of mu.
        """
        if use_exact:
            l_mu = [1]
        else:
            l_mu = [2]
        if n_kernels == 1:
            return l_mu
        bin_size = 2.0 / (n_kernels - 1)  # score range from [-1, 1]
        l_mu.append(1 - bin_size / 2)  # mu: middle of the bin
        for i in np.arange(1, n_kernels - 1):
            l_mu.append(l_mu[i] - bin_size)
        return l_mu

    @staticmethod
    def kernel_sigmas(n_kernels, lamb, use_exact):
        """
        get sigmas for each guassian kernel.
        :param n_kernels: number of kernels (including exactmath.)
        :param lamb: scales the soft-match kernels' width relative to bin size
        :param use_exact: unused here; kept for signature compatibility
        :return: l_sigma, a list of simga
        """
        l_sigma = [0.00001]  # for exact match. small variance -> exact match
        if n_kernels == 1:
            return l_sigma
        # BUG FIX: bin_size was previously computed before the n_kernels == 1
        # guard, raising ZeroDivisionError for a single kernel and making the
        # guard unreachable. Compute it only when soft-match kernels exist.
        bin_size = 2.0 / (n_kernels - 1)
        l_sigma += [bin_size * lamb] * (n_kernels - 1)
        return l_sigma

    @staticmethod
    def weight_variable(shape, name):
        """Glorot/Xavier uniform-initialized TF variable of the given shape."""
        tmp = np.sqrt(6.0) / np.sqrt(shape[0] + shape[1])
        initial = tf.random_uniform(shape, minval=-tmp, maxval=tmp)
        return tf.Variable(initial, name=name)

    @staticmethod
    def re_pad(D, batch_size):
        """Clamp negative ids to 0 and zero-pad D up to batch_size rows."""
        D = np.array(D)
        D[D < 0] = 0
        if len(D) < batch_size:
            tmp = np.zeros((batch_size - len(D), D.shape[1]))
            D = np.concatenate((D, tmp), axis=0)
        return D

    def gen_mask(self, Q, D, use_exact=True):
        """
        Generate mask for the batch. Mask padding and OOV terms.
        Exact matches is alos masked if use_exat == False.
        :param Q: a batch of queries, [batch_size, max_len_q]
        :param D: a bacth of documents, [batch_size, max_len_d]
        :param use_exact: mask exact matches if set False.
        :return: a mask of shape [batch_size, max_len_q, max_len_d].
        """
        # Rows beyond len(Q) (a short final batch) remain all-zero.
        mask = np.zeros((self.batch_size, self.max_q_len, self.max_d_len))
        for b in range(len(Q)):
            for q in range(len(Q[b])):
                if Q[b, q] > 0:
                    mask[b, q, D[b] > 0] = 1
                    if not use_exact:
                        mask[b, q, D[b] == Q[b, q]] = 0
        return mask
class Map(DOMWidget, InteractMixin):
    """Jupyter widget wrapping a Leaflet map.

    Traits tagged ``sync=True`` are mirrored to the frontend model; traits
    additionally tagged ``o=True`` are passed to Leaflet as map options
    (see ``_default_options``). Layers and controls are held in validated
    tuples keyed by widget ``model_id`` to prevent duplicates.
    """

    _view_name = Unicode('LeafletMapView').tag(sync=True)
    _model_name = Unicode('LeafletMapModel').tag(sync=True)
    _view_module = Unicode('jupyter-leaflet').tag(sync=True)
    _model_module = Unicode('jupyter-leaflet').tag(sync=True)
    _view_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)
    _model_module_version = Unicode(EXTENSION_VERSION).tag(sync=True)

    # Map options
    center = List(def_loc).tag(sync=True, o=True)
    zoom_start = Int(12).tag(sync=True, o=True)
    zoom = Int(12).tag(sync=True, o=True)
    max_zoom = Int(18).tag(sync=True, o=True)
    min_zoom = Int(1).tag(sync=True, o=True)
    interpolation = Unicode('bilinear').tag(sync=True, o=True)
    crs = Enum(values=allowed_crs, default_value='EPSG3857').tag(sync=True)

    # Specification of the basemap
    basemap = Dict(default_value=dict(
        url='https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png',
        max_zoom=19,
        attribution='Map data (c) <a href="https://openstreetmap.org">OpenStreetMap</a> contributors'
    )).tag(sync=True, o=True)
    modisdate = Unicode('yesterday').tag(sync=True)

    # Interaction options
    dragging = Bool(True).tag(sync=True, o=True)
    touch_zoom = Bool(True).tag(sync=True, o=True)
    scroll_wheel_zoom = Bool(False).tag(sync=True, o=True)
    double_click_zoom = Bool(True).tag(sync=True, o=True)
    box_zoom = Bool(True).tag(sync=True, o=True)
    tap = Bool(True).tag(sync=True, o=True)
    tap_tolerance = Int(15).tag(sync=True, o=True)
    world_copy_jump = Bool(False).tag(sync=True, o=True)
    close_popup_on_click = Bool(True).tag(sync=True, o=True)
    bounce_at_zoom_limits = Bool(True).tag(sync=True, o=True)
    keyboard = Bool(True).tag(sync=True, o=True)
    keyboard_pan_offset = Int(80).tag(sync=True, o=True)
    keyboard_zoom_offset = Int(1).tag(sync=True, o=True)
    inertia = Bool(True).tag(sync=True, o=True)
    inertia_deceleration = Int(3000).tag(sync=True, o=True)
    inertia_max_speed = Int(1500).tag(sync=True, o=True)
    # inertia_threshold = Int(?, o=True).tag(sync=True)
    # fade_animation = Bool(?).tag(sync=True, o=True)
    # zoom_animation = Bool(?).tag(sync=True, o=True)
    zoom_animation_threshold = Int(4).tag(sync=True, o=True)
    # marker_zoom_animation = Bool(?).tag(sync=True, o=True)
    fullscreen = Bool(False).tag(sync=True, o=True)

    options = List(trait=Unicode).tag(sync=True)

    style = InstanceDict(MapStyle).tag(sync=True, **widget_serialization)
    default_style = InstanceDict(MapStyle).tag(sync=True, **widget_serialization)
    dragging_style = InstanceDict(MapStyle).tag(sync=True, **widget_serialization)

    # Control toggles and their singleton instances (shared at class level).
    zoom_control = Bool(True)
    zoom_control_instance = ZoomControl()
    attribution_control = Bool(True)
    attribution_control_instance = AttributionControl(position='bottomright')

    @default('dragging_style')
    def _default_dragging_style(self):
        return {'cursor': 'move'}

    @default('options')
    def _default_options(self):
        # Leaflet options are exactly the traits tagged o=True.
        return [name for name in self.traits(o=True)]

    # Visible map edges, pushed from the frontend (read-only on this side).
    south = Float(def_loc[0], read_only=True).tag(sync=True)
    north = Float(def_loc[0], read_only=True).tag(sync=True)
    east = Float(def_loc[1], read_only=True).tag(sync=True)
    west = Float(def_loc[1], read_only=True).tag(sync=True)

    layers = Tuple(trait=Instance(Layer)).tag(sync=True, **widget_serialization)

    @default('layers')
    def _default_layers(self):
        return (basemap_to_tiles(self.basemap, self.modisdate, base=True),)

    bounds = Tuple(read_only=True)
    bounds_polygon = Tuple(read_only=True)

    @observe('south', 'north', 'east', 'west')
    def _observe_bounds(self, change):
        # Derive bounds/bounds_polygon from the four edge traits.
        self.set_trait('bounds', ((self.south, self.west), (self.north, self.east)))
        self.set_trait('bounds_polygon', ((self.north, self.west), (self.north, self.east), (self.south, self.east), (self.south, self.west)))

    def __init__(self, **kwargs):
        super(Map, self).__init__(**kwargs)
        self.on_displayed(self._fire_children_displayed)
        self.on_msg(self._handle_leaflet_event)
        if self.zoom_control:
            self.add_control(self.zoom_control_instance)
        if self.attribution_control:
            self.add_control(self.attribution_control_instance)

    @observe('zoom_control')
    def observe_zoom_control(self, change):
        # Keep the zoom control's presence in sync with its boolean trait.
        if change['new']:
            self.add_control(self.zoom_control_instance)
        else:
            if self.zoom_control_instance in self.controls:
                self.remove_control(self.zoom_control_instance)

    @observe('attribution_control')
    def observe_attribution_control(self, change):
        # Keep the attribution control's presence in sync with its trait.
        if change['new']:
            self.add_control(self.attribution_control_instance)
        else:
            if self.attribution_control_instance in self.controls:
                self.remove_control(self.attribution_control_instance)

    def _fire_children_displayed(self, widget, **kwargs):
        # Propagate the display event to all child layers and controls.
        for layer in self.layers:
            layer._handle_displayed(**kwargs)
        for control in self.controls:
            control._handle_displayed(**kwargs)

    _layer_ids = List()

    @validate('layers')
    def _validate_layers(self, proposal):
        '''Validate layers list.

        Makes sure only one instance of any given layer can exist in the
        layers list.
        '''
        self._layer_ids = [l.model_id for l in proposal.value]
        if len(set(self._layer_ids)) != len(self._layer_ids):
            raise LayerException('duplicate layer detected, only use each layer once')
        return proposal.value

    def add_layer(self, layer):
        """Add a layer; a dict argument is treated as a basemap spec."""
        if isinstance(layer, dict):
            layer = basemap_to_tiles(layer)
        if layer.model_id in self._layer_ids:
            raise LayerException('layer already on map: %r' % layer)
        self.layers = tuple([l for l in self.layers] + [layer])

    def remove_layer(self, layer):
        """Remove a layer; raises LayerException if it is not on the map."""
        if layer.model_id not in self._layer_ids:
            raise LayerException('layer not on map: %r' % layer)
        self.layers = tuple([l for l in self.layers if l.model_id != layer.model_id])

    def substitute_layer(self, old, new):
        """Replace layer ``old`` with ``new`` in place (order preserved)."""
        if isinstance(new, dict):
            new = basemap_to_tiles(new)
        if old.model_id not in self._layer_ids:
            raise LayerException('Could not substitute layer: layer not on map.')
        self.layers = tuple([new if l.model_id == old.model_id else l for l in self.layers])

    def clear_layers(self):
        """Remove all layers from the map."""
        self.layers = ()

    controls = Tuple(trait=Instance(Control)).tag(sync=True, **widget_serialization)
    _control_ids = List()

    @validate('controls')
    def _validate_controls(self, proposal):
        '''Validate controls list.

        Makes sure only one instance of any given layer can exist in the
        controls list.
        '''
        self._control_ids = [c.model_id for c in proposal.value]
        if len(set(self._control_ids)) != len(self._control_ids):
            raise ControlException('duplicate control detected, only use each control once')
        return proposal.value

    def add_control(self, control):
        """Add a control; raises ControlException on duplicates."""
        if control.model_id in self._control_ids:
            raise ControlException('control already on map: %r' % control)
        self.controls = tuple([c for c in self.controls] + [control])

    def remove_control(self, control):
        """Remove a control; raises ControlException if not present."""
        if control.model_id not in self._control_ids:
            raise ControlException('control not on map: %r' % control)
        self.controls = tuple([c for c in self.controls if c.model_id != control.model_id])

    def clear_controls(self):
        """Remove all controls from the map."""
        self.controls = ()

    def __iadd__(self, item):
        # map += layer_or_control
        if isinstance(item, Layer):
            self.add_layer(item)
        elif isinstance(item, Control):
            self.add_control(item)
        return self

    def __isub__(self, item):
        # map -= layer_or_control
        if isinstance(item, Layer):
            self.remove_layer(item)
        elif isinstance(item, Control):
            self.remove_control(item)
        return self

    def __add__(self, item):
        # map + item mutates and returns the same map (mirrors __iadd__).
        if isinstance(item, Layer):
            self.add_layer(item)
        elif isinstance(item, Control):
            self.add_control(item)
        return self

    # Event handling
    _interaction_callbacks = Instance(CallbackDispatcher, ())

    def _handle_leaflet_event(self, _, content, buffers):
        # Only 'interaction' events are dispatched to Python callbacks.
        if content.get('event', '') == 'interaction':
            self._interaction_callbacks(**content)

    def on_interaction(self, callback, remove=False):
        """Register (or with remove=True, unregister) an interaction callback."""
        self._interaction_callbacks.register_callback(callback, remove=remove)
class Main(Experiment):
    """Evaluate precision-recall of a trained LaSO/set-operations COCO model.

    Loads (or resumes) the base embedding network, classifier and set-ops
    module, runs them over COCO validation pairs, and writes per-operation
    average-precision results to pickle files.
    """

    description = Unicode(u"Calculate precision-recall accuracy of trained coco model.")

    #
    # Run setup
    #
    batch_size = Int(256, config=True, help="Batch size. default: 256")
    num_workers = Int(8, config=True, help="Number of workers to use for data loading. default: 8")
    device = Unicode("cuda", config=True, help="Use `cuda` backend. default: cuda")

    #
    # Hyper parameters.
    #
    unseen = Bool(False, config=True, help="Test on unseen classes.")
    skip_tests = Int(1, config=True, help="How many test pairs to skip? for better runtime. default: 1")
    debug_size = Int(-1, config=True, help="Reduce dataset sizes. This is useful when developing the script. default -1")

    #
    # Resume previous run parameters.
    #
    resume_path = Unicode(u"/dccstor/alfassy/finalLaSO/code_release/paperModels", config=True,
                          help="Resume from checkpoint file (requires using also '--resume_epoch'.")
    resume_epoch = Int(0, config=True, help="Epoch to resume (requires using also '--resume_path'.")
    coco_path = Unicode(u"/tmp/aa/coco", config=True, help="path to local coco dataset path")
    init_inception = Bool(True, config=True, help="Initialize the inception networks using the paper's base network.")

    #
    # Network hyper parameters
    #
    base_network_name = Unicode("Inception3", config=True, help="Name of base network to use.")
    avgpool_kernel = Int(10, config=True,
                         help="Size of the last avgpool layer in the Resnet. Should match the cropsize.")
    classifier_name = Unicode("Inception3Classifier", config=True, help="Name of classifier to use.")
    sets_network_name = Unicode("SetOpsResModule", config=True, help="Name of setops module to use.")
    sets_block_name = Unicode("SetopResBlock_v1", config=True, help="Name of setops network to use.")
    sets_basic_block_name = Unicode("SetopResBasicBlock", config=True,
                                    help="Name of the basic setops block to use (where applicable).")
    ops_layer_num = Int(1, config=True, help="Ops Module layers num.")
    ops_latent_dim = Int(1024, config=True, help="Ops Module inner latent dim.")
    setops_dropout = Float(0, config=True, help="Dropout ratio of setops module.")
    crop_size = Int(299, config=True, help="Size of input crop (Resnet 224, inception 299).")
    scale_size = Int(350, config=True, help="Size of input scale for data augmentation. default: 350")
    paper_reproduce = Bool(False, config=True, help="Use paper reproduction settings. default: False")
    discriminator_name = Unicode("AmitDiscriminator", config=True,
                                 help="Name of discriminator (unseen classifier) to use. default: AmitDiscriminator")
    embedding_dim = Int(2048, config=True, help="Dimensionality of the LaSO space. default:2048")
    classifier_latent_dim = Int(2048, config=True, help="Dimensionality of the classifier latent space. default:2048")

    def run(self):
        """Run the full evaluation: classify pairs, score set operations,
        and dump per-operation average-precision results."""
        #
        # Setup the model
        #
        base_model, classifier, setops_model = self.setup_model()
        base_model.to(self.device)
        classifier.to(self.device)
        setops_model.to(self.device)
        base_model.eval()
        classifier.eval()
        setops_model.eval()
        #
        # Load the dataset
        #
        pair_dataset, pair_loader, pair_dataset_sub, pair_loader_sub = self.setup_datasets()
        logging.info("Calcualting classifications:")
        output_a_list, output_b_list, fake_a_list, fake_b_list, target_a_list, target_b_list = [], [], [], [], [], []
        a_S_b_list, b_S_a_list, a_U_b_list, b_U_a_list, a_I_b_list, b_I_a_list = [], [], [], [], [], []
        target_a_I_b_list, target_a_U_b_list, target_a_S_b_list, target_b_S_a_list = [], [], [], []
        with torch.no_grad():
            for batch in tqdm(pair_loader):
                input_a, input_b, target_a, target_b = _prepare_batch(batch, device=self.device)
                #
                # Apply the classification model
                #
                embed_a = base_model(input_a).view(input_a.size(0), -1)
                embed_b = base_model(input_b).view(input_b.size(0), -1)
                output_a = classifier(embed_a)
                output_b = classifier(embed_b)
                #
                # Apply the setops model. Slots 2:8 of its output are the six
                # set-operation embeddings (a-b, b-a, aUb, bUa, a&b, b&a).
                #
                outputs_setopt = setops_model(embed_a, embed_b)
                a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a = \
                    [classifier(o) for o in outputs_setopt[2:8]]
                output_a_list.append(output_a.cpu().numpy())
                output_b_list.append(output_b.cpu().numpy())
                # fake_a_list.append(fake_a.cpu().numpy())
                # fake_b_list.append(fake_b.cpu().numpy())
                a_S_b_list.append(a_S_b.cpu().numpy())
                b_S_a_list.append(b_S_a.cpu().numpy())
                a_U_b_list.append(a_U_b.cpu().numpy())
                b_U_a_list.append(b_U_a.cpu().numpy())
                a_I_b_list.append(a_I_b.cpu().numpy())
                b_I_a_list.append(b_I_a.cpu().numpy())
                #
                # Calculate the target setops operations (bitwise on labels)
                #
                target_a_list.append(target_a.cpu().numpy())
                target_b_list.append(target_b.cpu().numpy())
                target_a = target_a.type(torch.cuda.ByteTensor)
                target_b = target_b.type(torch.cuda.ByteTensor)
                target_a_I_b = target_a & target_b
                target_a_U_b = target_a | target_b
                target_a_S_b = target_a & ~target_a_I_b
                target_b_S_a = target_b & ~target_a_I_b
                target_a_I_b_list.append(target_a_I_b.type(torch.cuda.FloatTensor).cpu().numpy())
                target_a_U_b_list.append(target_a_U_b.type(torch.cuda.FloatTensor).cpu().numpy())
                target_a_S_b_list.append(target_a_S_b.type(torch.cuda.FloatTensor).cpu().numpy())
                target_b_S_a_list.append(target_b_S_a.type(torch.cuda.FloatTensor).cpu().numpy())
        # Second pass: subtraction is re-evaluated on its dedicated loader,
        # overwriting the subtraction results gathered above.
        logging.info("Calculating classifications for subtraction independently:")
        a_S_b_list, b_S_a_list = [], []
        target_a_S_b_list, target_b_S_a_list = [], []
        with torch.no_grad():
            for batch in tqdm(pair_loader_sub):
                input_a, input_b, target_a, target_b = _prepare_batch(batch, device=self.device)
                #
                # Apply the classification model
                #
                embed_a = base_model(input_a).view(input_a.size(0), -1)
                embed_b = base_model(input_b).view(input_b.size(0), -1)
                #
                # Apply the setops model (only the two subtraction slots used).
                #
                outputs_setopt = setops_model(embed_a, embed_b)
                a_S_b, b_S_a, _, _, _, _ = \
                    [classifier(o) for o in outputs_setopt[2:8]]
                a_S_b_list.append(a_S_b.cpu().numpy())
                b_S_a_list.append(b_S_a.cpu().numpy())
                #
                # Calculate the target setops operations
                #
                target_a = target_a.type(torch.cuda.ByteTensor)
                target_b = target_b.type(torch.cuda.ByteTensor)
                target_a_I_b = target_a & target_b
                target_a_S_b = target_a & ~target_a_I_b
                target_b_S_a = target_b & ~target_a_I_b
                target_a_S_b_list.append(target_a_S_b.type(torch.cuda.FloatTensor).cpu().numpy())
                target_b_S_a_list.append(target_b_S_a.type(torch.cuda.FloatTensor).cpu().numpy())
        #
        # Output restuls
        #
        logging.info("Calculating precision:")
        # Union targets are symmetric, hence target_a_U_b_list twice below.
        for output, target, name in tqdm(zip(
                (output_a_list, output_b_list, a_S_b_list, b_S_a_list, a_U_b_list, b_U_a_list, a_I_b_list, b_I_a_list),
                (target_a_list, target_b_list, target_a_S_b_list, target_b_S_a_list, target_a_U_b_list, target_a_U_b_list, target_a_I_b_list, target_a_I_b_list),
                ("real_a", "real_b", "a_S_b", "b_S_a", "a_U_b", "b_U_a", "a_I_b", "b_I_a"))):
            output = np.concatenate(output, axis=0)
            target = np.concatenate(target, axis=0)
            # Per-class average precision and PR curves.
            ap = [average_precision_score(target[:, i], output[:, i]) for i in range(output.shape[1])]
            pr_graphs = [precision_recall_curve(target[:, i], output[:, i]) for i in range(output.shape[1])]
            # Macro-average over the dataset's label set only.
            ap_sum = 0
            for label in pair_dataset.labels_list:
                ap_sum += ap[label]
            ap_avg = ap_sum / len(pair_dataset.labels_list)
            logging.info(
                'Test {} average precision score, macro-averaged over all {} classes: {}'.format(
                    name, len(pair_dataset.labels_list), ap_avg)
            )
            with open(os.path.join(self.results_path, "{}_results.pkl".format(name)), "wb") as f:
                pickle.dump(dict(ap=ap, pr_graphs=pr_graphs), f)

    def setup_model(self):
        """Create or resume the models."""
        logging.info("Setup the models.")
        logging.info("{} model".format(self.base_network_name))
        models_path = Path(self.resume_path)
        # NOTE(review): branch nesting below reconstructed from collapsed
        # source — confirm against VCS.
        if self.base_network_name.lower().startswith("resnet"):
            base_model, classifier = getattr(setops_models, self.base_network_name)(
                num_classes=80,
                avgpool_kernel=self.avgpool_kernel
            )
        else:
            base_model = setops_models.Inception3(aux_logits=False, transform_input=True)
            classifier = getattr(setops_models, self.classifier_name)(num_classes=80)
            if self.init_inception:
                logging.info("Initialize inception model using paper's networks.")
                checkpoint = torch.load(models_path / 'paperBaseModel')
                base_model = setops_models.Inception3(aux_logits=False, transform_input=True)
                # Load only the weights that exist in each target network.
                base_model.load_state_dict(
                    {k: v for k, v in checkpoint["state_dict"].items() if k in base_model.state_dict()}
                )
                classifier.load_state_dict(
                    {k: v for k, v in checkpoint["state_dict"].items() if k in classifier.state_dict()}
                )
        setops_model_cls = getattr(setops_models, self.sets_network_name)
        setops_model = setops_model_cls(
            input_dim=self.embedding_dim,
            S_latent_dim=self.ops_latent_dim, S_layers_num=self.ops_layer_num,
            I_latent_dim=self.ops_latent_dim, I_layers_num=self.ops_layer_num,
            U_latent_dim=self.ops_latent_dim, U_layers_num=self.ops_layer_num,
            block_cls_name=self.sets_block_name, basic_block_cls_name=self.sets_basic_block_name,
            dropout_ratio=self.setops_dropout,
        )
        if self.unseen:
            #
            # In the unseen mode, we have to load the trained discriminator.
            #
            discriminator_cls = getattr(setops_models, self.discriminator_name)
            classifier = discriminator_cls(
                input_dim=self.embedding_dim,
                latent_dim=self.classifier_latent_dim
            )
        if not self.resume_path:
            raise FileNotFoundError("resume_path is compulsory in test_precision")
        logging.info("Resuming the models.")
        if not self.init_inception:
            base_model.load_state_dict(
                torch.load(sorted(models_path.glob("networks_base_model_{}*.pth".format(self.resume_epoch)))[-1])
            )
        if self.paper_reproduce:
            logging.info("using paper models")
            setops_model_cls = getattr(setops_models, "SetOpsModulePaper")
            setops_model = setops_model_cls(models_path)
            if self.unseen:
                checkpoint = torch.load(models_path / 'paperDiscriminator')
                classifier.load_state_dict(checkpoint['state_dict'])
        else:
            setops_model.load_state_dict(
                torch.load(
                    sorted(
                        models_path.glob("networks_setops_model_{}*.pth".format(self.resume_epoch))
                    )[-1]
                )
            )
            if self.unseen:
                classifier.load_state_dict(
                    torch.load(sorted(models_path.glob("networks_discriminator_{}*.pth".format(self.resume_epoch)))[-1])
                )
            elif not self.init_inception:
                classifier.load_state_dict(
                    torch.load(sorted(models_path.glob("networks_classifier_{}*.pth".format(self.resume_epoch)))[-1])
                )
        return base_model, classifier, setops_model

    def setup_datasets(self):
        """Load the training datasets."""
        # TODO: comment out if you don't want to copy coco to /tmp/aa
        # copy_coco_data()
        logging.info("Setting up the datasets.")
        CocoDatasetPairs = getattr(alfassy, "CocoDatasetPairs")
        CocoDatasetPairsSub = getattr(alfassy, "CocoDatasetPairsSub")
        if self.paper_reproduce:
            logging.info("Setting up the datasets and augmentation for paper reproduction")
            scaler = transforms.Scale((350, 350))
        else:
            scaler = transforms.Resize(self.crop_size)
        # Standard ImageNet normalization after scaling and center-cropping.
        val_transform = transforms.Compose(
            [
                scaler,
                transforms.CenterCrop(self.crop_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )
        pair_dataset = CocoDatasetPairs(
            root_dir=self.coco_path,
            set_name='val2014',
            unseen_set=self.unseen,
            transform=val_transform,
            debug_size=self.debug_size
        )
        pair_loader = DataLoader(
            pair_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )
        pair_dataset_sub = CocoDatasetPairsSub(
            root_dir=self.coco_path,
            set_name='val2014',
            unseen_set=self.unseen,
            transform=val_transform,
            debug_size=self.debug_size
        )
        pair_loader_sub = DataLoader(
            pair_dataset_sub,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )
        return pair_dataset, pair_loader, pair_dataset_sub, pair_loader_sub
class FargateSpawner(Spawner):
    """JupyterHub spawner that runs each single-user server as an AWS
    Fargate (ECS) task, polling the ECS API for task IP and status."""

    aws_region = Unicode(config=True)
    aws_ecs_host = Unicode(config=True)
    task_role_arn = Unicode(config=True)
    task_cluster_name = Unicode(config=True)
    task_container_name = Unicode(config=True)
    task_definition_arn = Unicode(config=True)
    task_security_groups = List(trait=Unicode, config=True)
    task_subnets = List(trait=Unicode, config=True)
    task_assign_public_ip = Enum(["DISABLED", "ENABLED"], "DISABLED", config=True)
    task_platform_version = Unicode("LATEST", config=True)
    notebook_port = Int(config=True)
    notebook_scheme = Unicode(config=True)
    notebook_args = List(trait=Unicode, config=True)

    authentication_class = Type(FargateSpawnerAuthentication, config=True)
    authentication = Instance(FargateSpawnerAuthentication)

    @default('authentication')
    def _default_authentication(self):
        return self.authentication_class(parent=self)

    # ARN of the currently running task; '' means no task.
    task_arn = Unicode('')

    # We mostly are able to call the AWS API to determine status. However, when we yield the
    # event loop to create the task, if there is a poll before the creation is complete,
    # we must behave as though we are running/starting, but we have no IDs to use with which
    # to check the task.
    calling_run_task = Bool(False)

    # Set by clear_state(); streams progress messages to the progress() API.
    progress_buffer = None

    def load_state(self, state):
        ''' Misleading name: this "loads" the state onto self, to be used by other methods '''
        super().load_state(state)
        # Called when first created: we might have no state from a previous invocation
        self.task_arn = state.get('task_arn', '')

    def get_state(self):
        '''
        Misleading name: the return value of get_state is saved to the database in order
        to be able to restore after the hub went down
        '''
        state = super().get_state()
        state['task_arn'] = self.task_arn
        return state

    async def poll(self):
        # Return values, as dictated by the Jupyterhub framework:
        # 0     == not running, or not starting up, i.e. we need to call start
        # None  == running, or not finished starting
        # 1, or anything else == error
        return \
            None if self.calling_run_task else \
            0 if self.task_arn == '' else \
            None if (await _get_task_status(self.log, self._aws_endpoint(), self.task_cluster_name, self.task_arn)) in ALLOWED_STATUSES else \
            1

    async def start(self):
        """Launch the Fargate task, then poll until it has an IP and is RUNNING.

        Returns the URL of the started single-user server. Raises on timeout
        or when the task enters a disallowed status.
        """
        self.log.debug('Starting spawner')

        task_port = self.notebook_port

        self.progress_buffer.write({'progress': 0.5, 'message': 'Starting server...'})
        try:
            self.calling_run_task = True
            debug_args = ['--debug'] if self.debug else []
            args = debug_args + ['--port=' + str(task_port)] + self.notebook_args
            run_response = await _run_task(
                self.log, self._aws_endpoint(),
                self.task_role_arn, self.task_cluster_name, self.task_container_name,
                self.task_definition_arn, self.task_security_groups, self.task_subnets,
                self.task_assign_public_ip, self.task_platform_version,
                self.cmd + args, self.get_env(), self.user_options)
            task_arn = run_response['tasks'][0]['taskArn']
            self.progress_buffer.write({'progress': 1})
        finally:
            self.calling_run_task = False

        self.task_arn = task_arn

        # Phase 1: wait (up to 50 polls) for the task to be assigned an IP.
        max_polls = 50
        num_polls = 0
        task_ip = ''
        while task_ip == '':
            num_polls += 1
            if num_polls >= max_polls:
                raise Exception('Task {} took too long to find IP address'.format(self.task_arn))

            task_ip = await _get_task_ip(self.log, self._aws_endpoint(), self.task_cluster_name, task_arn)
            await gen.sleep(1)
            self.progress_buffer.write({'progress': 1 + num_polls / max_polls})

        self.progress_buffer.write({'progress': 2})

        # Phase 2: wait (up to start_timeout polls) for RUNNING status.
        max_polls = self.start_timeout
        num_polls = 0
        status = ''
        while status != 'RUNNING':
            num_polls += 1
            if num_polls >= max_polls:
                raise Exception('Task {} took too long to become running'.format(self.task_arn))

            status = await _get_task_status(self.log, self._aws_endpoint(), self.task_cluster_name, task_arn)
            if status not in ALLOWED_STATUSES:
                raise Exception('Task {} is {}'.format(self.task_arn, status))

            await gen.sleep(1)
            self.progress_buffer.write({'progress': 2 + num_polls / max_polls * 98})

        self.progress_buffer.write({'progress': 100, 'message': 'Server started'})
        await gen.sleep(1)

        self.progress_buffer.close()

        return f'{self.notebook_scheme}://{task_ip}:{task_port}'

    async def stop(self, now=False):
        # No-op when no task was ever started for this spawner.
        if self.task_arn == '':
            return

        self.log.debug('Stopping task (%s)...', self.task_arn)
        await _ensure_stopped_task(self.log, self._aws_endpoint(), self.task_cluster_name, self.task_arn)
        self.log.debug('Stopped task (%s)... (done)', self.task_arn)

    def clear_state(self):
        super().clear_state()
        self.log.debug('Clearing state: (%s)', self.task_arn)
        self.task_arn = ''
        # Fresh buffer for the next start()'s progress stream.
        self.progress_buffer = AsyncIteratorBuffer()

    async def progress(self):
        # Stream progress messages written by start() to the Hub UI.
        async for progress_message in self.progress_buffer:
            yield progress_message

    def _aws_endpoint(self):
        return {
            'region': self.aws_region,
            'ecs_host': self.aws_ecs_host,
            'ecs_auth': self.authentication.get_credentials,
        }
class Widget(LoggingHasTraits):
    """Base class for widget models.

    Holds the traits that are synchronized with the front-end over a Comm,
    sends/receives state update messages, and tracks all live widget
    instances in the class-level ``widgets`` registry.
    """
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    # Optional callable invoked every time a widget is constructed.
    _widget_construction_callback = None

    # widgets is a dictionary of all active widget objects
    widgets = {}

    # widget_types is a registry of widgets by module, version, and name:
    widget_types = WidgetRegistry()

    @classmethod
    def close_all(cls):
        """Close every live widget (and its comm)."""
        for widget in list(cls.widgets.values()):
            widget.close()

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when the front-end opens a comm for a new widget.

        Validates the protocol version, looks up the widget class from the
        state in the message, instantiates it on the given comm, and applies
        the received state."""
        version = msg.get('metadata', {}).get('version', '')
        if version.split('.')[0] != PROTOCOL_VERSION_MAJOR:
            raise ValueError("Incompatible widget protocol versions: received version %r, expected version %r"%(version, __protocol_version__))
        data = msg['content']['data']
        state = data['state']

        # Find the widget class to instantiate in the registered widgets
        widget_class = Widget.widget_types.get(state['_model_module'],
                                               state['_model_module_version'],
                                               state['_model_name'],
                                               state['_view_module'],
                                               state['_view_module_version'],
                                               state['_view_name'])
        widget = widget_class(comm=comm)
        if 'buffer_paths' in data:
            # Re-attach binary buffers to the state before applying it.
            _put_buffers(state, data['buffer_paths'], msg['buffers'])
        widget.set_state(state)

    @staticmethod
    def get_manager_state(drop_defaults=False, widgets=None):
        """Returns the full state for a widget manager for embedding

        :param drop_defaults: when True, it will not include default value
        :param widgets: list with widgets to include in the state (or all widgets when None)
        :return: dict with 'version_major', 'version_minor' and per-model 'state'
        """
        state = {}
        if widgets is None:
            widgets = Widget.widgets.values()
        for widget in widgets:
            state[widget.model_id] = widget._get_embed_state(drop_defaults=drop_defaults)
        return {'version_major': 2, 'version_minor': 0, 'state': state}

    def _get_embed_state(self, drop_defaults=False):
        """Return this widget's serialized state (buffers base64-encoded) for embedding."""
        state = {
            'model_name': self._model_name,
            'model_module': self._model_module,
            'model_module_version': self._model_module_version
        }
        model_state, buffer_paths, buffers = _remove_buffers(self.get_state(drop_defaults=drop_defaults))
        state['state'] = model_state
        if len(buffers) > 0:
            state['buffers'] = [{'encoding': 'base64',
                                 'path': p,
                                 'data': standard_b64encode(d).decode('ascii')}
                                for p, d in zip(buffer_paths, buffers)]
        return state

    def get_view_spec(self):
        """Return the minimal spec (version + model id) identifying a view of this widget."""
        return dict(version_major=2, version_minor=0, model_id=self._model_id)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_name = Unicode('WidgetModel',
        help="Name of the model.", read_only=True).tag(sync=True)
    _model_module = Unicode('@jupyter-widgets/base',
        help="The namespace for the model.", read_only=True).tag(sync=True)
    _model_module_version = Unicode(__jupyter_widgets_base_version__,
        help="A semver requirement for namespace version containing the model.", read_only=True).tag(sync=True)
    _view_name = Unicode(None, allow_none=True,
        help="Name of the view.").tag(sync=True)
    _view_module = Unicode(None, allow_none=True,
        help="The namespace for the view.").tag(sync=True)
    _view_module_version = Unicode('',
        help="A semver requirement for the namespace version containing the view.").tag(sync=True)
    _view_count = Int(None, allow_none=True,
        help="EXPERIMENTAL: The number of views of the model displayed in the frontend. This attribute is experimental and may change or be removed in the future. None signifies that views will not be tracked. Set this to 0 to start tracking view creation/deletion.").tag(sync=True)
    # The kernel-side comm this widget talks over; None until open() runs.
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    keys = List(help="The traits which are synced.")

    @default('keys')
    def _default_keys(self):
        # All trait names tagged with sync=True are synchronized.
        return [name for name in self.traits(sync=True)]

    # State echoed back from the front-end; used to suppress echo sends.
    _property_lock = Dict()
    # True while inside hold_sync(); sends are deferred to _states_to_send.
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------
    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_paths, buffers = _remove_buffers(self.get_state())

            args = dict(target_name='jupyter.widget',
                        data={'state': state, 'buffer_paths': buffer_paths},
                        buffers=buffers,
                        metadata={'version': __protocol_version__}
                        )
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)

    @observe('comm')
    def _comm_changed(self, change):
        """Called when the comm is changed."""
        if change['new'] is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        # Register in the global registry so the widget can be found by id.
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            self._ipython_display_ = None

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end, if it exists.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        if len(state) > 0:
            state, buffer_paths, buffers = _remove_buffers(state)
            msg = {'method': 'update', 'state': state, 'buffer_paths': buffer_paths}
            self._send(msg, buffers=buffers)

    def get_state(self, key=None, drop_defaults=False):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict
            The requested trait names mapped to their JSON-converted values.
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        # NOTE(review): collections.Iterable is deprecated since Python 3.3
        # and removed in 3.10 (use collections.abc.Iterable); left as-is here
        # because this file still maintains Python 2 compatibility (PY3 checks).
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits()
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            # On Python 2, wrap bytes traits in memoryview so they travel as buffers.
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(value, bytes):
                value = memoryview(value)
            if not drop_defaults or not self._compare(value, traits[k].default_value):
                state[k] = value
        return state

    def _is_numpy(self, x):
        # Duck-typed check that avoids importing numpy unless needed.
        return x.__class__.__name__ == 'ndarray' and x.__class__.__module__ == 'numpy'

    def _compare(self, a, b):
        # Equality that tolerates numpy arrays (== on arrays is elementwise).
        if self._is_numpy(a) or self._is_numpy(b):
            import numpy as np
            return np.array_equal(a, b)
        else:
            return a == b

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json', self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state to the frontend before the user-registered callbacks
        # are called.
        name = change['name']
        if self.comm is not None and self.comm.kernel is not None:
            # Make sure this isn't information that the front-end just sent us.
            if name in self.keys and self._should_send_property(name, getattr(self, name)):
                # Send new state to front-end
                self.send_state(key=name)
        super(Widget, self).notify_change(change)

    def __repr__(self):
        return self._gen_repr_from_keys(self._repr_keys())

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed. In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            # Already inside an outer hold_sync; let it do the flushing.
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if key in self._property_lock:
            # model_state, buffer_paths, buffers
            split_value = _remove_buffers({ key: to_json(value, self)})
            split_lock = _remove_buffers({ key: self._property_lock[key]})
            # A roundtrip conversion through json in the comparison takes care of
            # idiosyncracies of how python data structures map to json, for example
            # tuples get converted to lists.
            if (jsonloads(jsondumps(split_value[0])) == split_lock[0]
                and split_value[1] == split_lock[1]
                and _buffer_list_equal(split_value[2], split_lock[2])):
                return False
        if self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        if method == 'update':
            if 'state' in data:
                state = data['state']
                if 'buffer_paths' in data:
                    _put_buffers(state, data['buffer_paths'], msg['buffers'])
                self.set_state(state)

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        if self._view_name is not None:
            # The 'application/vnd.jupyter.widget-view+json' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            data = {
                'text/plain': repr(self),
                'text/html': self._fallback_html(),
                'application/vnd.jupyter.widget-view+json': {
                    'version_major': 2,
                    'version_minor': 0,
                    'model_id': self._model_id
                }
            }
            display(data, raw=True)
            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        if self.comm is not None and self.comm.kernel is not None:
            self.comm.send(data=msg, buffers=buffers)

    def _repr_keys(self):
        # Yield the trait names worth showing in repr: public, non-default,
        # non-empty-dynamic-default, in sorted order.
        traits = self.traits()
        for key in sorted(self.keys):
            # Exclude traits that start with an underscore
            if key[0] == '_':
                continue
            # Exclude traits who are equal to their default value
            value = getattr(self, key)
            trait = traits[key]
            if self._compare(value, trait.default_value):
                continue
            elif (isinstance(trait, (Container, Dict)) and
                  trait.default_value == Undefined and
                  (value is None or len(value) == 0)):
                # Empty container, and dynamic default will be empty
                continue
            yield key

    def _gen_repr_from_keys(self, keys):
        class_name = self.__class__.__name__
        signature = ', '.join(
            '%s=%r' % (key, getattr(self, key))
            for key in keys
        )
        return '%s(%s)' % (class_name, signature)

    def _fallback_html(self):
        # Plain-HTML stand-in shown where the widget frontend is unavailable.
        return _FALLBACK_HTML_TEMPLATE.format(widget_type=type(self).__name__)
class Widget(LoggingConfigurable):
    """Widget base class (protocol-version-1 variant).

    Synchronizes sync-tagged traits with the front-end over a Comm and
    tracks live widgets in the class-level ``widgets`` registry.
    """
    #-------------------------------------------------------------------------
    # Class attributes
    #-------------------------------------------------------------------------
    # Optional callable invoked every time a widget is constructed.
    _widget_construction_callback = None
    # All live widget instances, keyed by model id.
    widgets = {}
    # Registry mapping widget class names to widget classes.
    widget_types = {}

    @staticmethod
    def on_widget_constructed(callback):
        """Registers a callback to be called when a widget is constructed.

        The callback must have the following signature:
        callback(widget)"""
        Widget._widget_construction_callback = callback

    @staticmethod
    def _call_widget_constructed(widget):
        """Static method, called when a widget is constructed."""
        if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
            Widget._widget_construction_callback(widget)

    @staticmethod
    def handle_comm_opened(comm, msg):
        """Static method, called when the front-end opens a comm for a new widget.

        Resolves the widget class named in the message (from the registry, or
        by dotted-path import) and instantiates it on the given comm.
        """
        class_name = str(msg['content']['data']['widget_class'])
        if class_name in Widget.widget_types:
            widget_class = Widget.widget_types[class_name]
        else:
            widget_class = import_item(class_name)
        # NOTE(review): the instance is kept alive via Widget.widgets (see
        # _comm_changed), not via this local binding.
        widget = widget_class(comm=comm)

    @staticmethod
    def get_manager_state(drop_defaults=False):
        """Return the full widget-manager state (version 1 format) for embedding."""
        return dict(version_major=1, version_minor=0, state={
            k: {
                'model_name': Widget.widgets[k]._model_name,
                'model_module': Widget.widgets[k]._model_module,
                'model_module_version': Widget.widgets[k]._model_module_version,
                'state': Widget.widgets[k].get_state(drop_defaults=drop_defaults)
            } for k in Widget.widgets
        })

    def get_view_spec(self):
        """Return the minimal spec (version + model id) identifying a view of this widget."""
        return dict(version_major=1, version_minor=0, model_id=self._model_id)

    #-------------------------------------------------------------------------
    # Traits
    #-------------------------------------------------------------------------
    _model_module = Unicode('jupyter-js-widgets',
        help="A JavaScript module name in which to find _model_name.").tag(sync=True)
    _model_name = Unicode('WidgetModel',
        help="Name of the model object in the front-end.").tag(sync=True)
    _model_module_version = Unicode('*',
        help="A semver requirement for the model module version.").tag(sync=True)
    _view_module = Unicode(None, allow_none=True,
        help="A JavaScript module in which to find _view_name.").tag(sync=True)
    _view_name = Unicode(None, allow_none=True,
        help="Name of the view object.").tag(sync=True)
    _view_module_version = Unicode('*',
        help="A semver requirement for the view module version.").tag(sync=True)
    # The kernel-side comm this widget talks over; None until open() runs.
    comm = Instance('ipykernel.comm.Comm', allow_none=True)

    msg_throttle = Int(1, help="""Maximum number of msgs the
        front-end can send before receiving an idle msg from the back-end.""").tag(sync=True)

    keys = List()

    def _keys_default(self):
        # All trait names tagged with sync=True are synchronized.
        return [name for name in self.traits(sync=True)]

    # State echoed back from the front-end; used to suppress echo sends.
    _property_lock = Dict()
    # True while inside hold_sync(); sends are deferred to _states_to_send.
    _holding_sync = False
    _states_to_send = Set()
    _display_callbacks = Instance(CallbackDispatcher, ())
    _msg_callbacks = Instance(CallbackDispatcher, ())

    #-------------------------------------------------------------------------
    # (Con/de)structor
    #-------------------------------------------------------------------------
    def __init__(self, **kwargs):
        """Public constructor"""
        self._model_id = kwargs.pop('model_id', None)
        super(Widget, self).__init__(**kwargs)

        Widget._call_widget_constructed(self)
        self.open()

    def __del__(self):
        """Object disposal"""
        self.close()

    #-------------------------------------------------------------------------
    # Properties
    #-------------------------------------------------------------------------
    def open(self):
        """Open a comm to the frontend if one isn't already open."""
        if self.comm is None:
            state, buffer_keys, buffers = self._split_state_buffers(self.get_state())

            args = dict(target_name='jupyter.widget', data=state)
            if self._model_id is not None:
                args['comm_id'] = self._model_id

            self.comm = Comm(**args)

            if buffers:
                # FIXME: workaround ipykernel missing binary message support in open-on-init
                # send state with binary elements as second message
                self.send_state()

    @observe('comm')
    def _comm_changed(self, change):
        """Called when the comm is changed."""
        if change['new'] is None:
            return
        self._model_id = self.model_id

        self.comm.on_msg(self._handle_msg)
        # Register in the global registry so the widget can be found by id.
        Widget.widgets[self.model_id] = self

    @property
    def model_id(self):
        """Gets the model id of this widget.

        If a Comm doesn't exist yet, a Comm will be created automagically."""
        return self.comm.comm_id

    #-------------------------------------------------------------------------
    # Methods
    #-------------------------------------------------------------------------

    def close(self):
        """Close method.

        Closes the underlying comm.
        When the comm is closed, all of the widget views are automatically
        removed from the front-end."""
        if self.comm is not None:
            Widget.widgets.pop(self.model_id, None)
            self.comm.close()
            self.comm = None
            self._ipython_display_ = None

    def _split_state_buffers(self, state):
        """Return (state_without_buffers, buffer_keys, buffers) for binary message parts"""
        buffer_keys, buffers = [], []
        for k, v in list(state.items()):
            if isinstance(v, _binary_types):
                state.pop(k)
                buffers.append(v)
                buffer_keys.append(k)
        return state, buffer_keys, buffers

    def send_state(self, key=None):
        """Sends the widget state, or a piece of it, to the front-end.

        Parameters
        ----------
        key : unicode, or iterable (optional)
            A single property's name or iterable of property names to sync with the front-end.
        """
        state = self.get_state(key=key)
        state, buffer_keys, buffers = self._split_state_buffers(state)
        msg = {'method': 'update', 'state': state, 'buffers': buffer_keys}
        self._send(msg, buffers=buffers)

    def get_state(self, key=None, drop_defaults=False):
        """Gets the widget state, or a piece of it.

        Parameters
        ----------
        key : unicode or iterable (optional)
            A single property's name or iterable of property names to get.

        Returns
        -------
        state : dict
            The requested trait names mapped to their JSON-converted values.
        """
        if key is None:
            keys = self.keys
        elif isinstance(key, string_types):
            keys = [key]
        # NOTE(review): collections.Iterable is deprecated since Python 3.3
        # and removed in 3.10 (use collections.abc.Iterable); left as-is here
        # because this file still maintains Python 2 compatibility (PY3 checks).
        elif isinstance(key, collections.Iterable):
            keys = key
        else:
            raise ValueError("key must be a string, an iterable of keys, or None")
        state = {}
        traits = self.traits()
        for k in keys:
            to_json = self.trait_metadata(k, 'to_json', self._trait_to_json)
            value = to_json(getattr(self, k), self)
            # On Python 2, wrap bytes traits in memoryview so they travel as buffers.
            if not PY3 and isinstance(traits[k], Bytes) and isinstance(value, bytes):
                value = memoryview(value)
            if not drop_defaults or value != traits[k].default_value:
                state[k] = value
        return state

    def set_state(self, sync_data):
        """Called when a state is received from the front-end."""
        # The order of these context managers is important. Properties must
        # be locked when the hold_trait_notification context manager is
        # released and notifications are fired.
        with self._lock_property(**sync_data), self.hold_trait_notifications():
            for name in sync_data:
                if name in self.keys:
                    from_json = self.trait_metadata(name, 'from_json', self._trait_from_json)
                    self.set_trait(name, from_json(sync_data[name], self))

    def send(self, content, buffers=None):
        """Sends a custom msg to the widget model in the front-end.

        Parameters
        ----------
        content : dict
            Content of the message to send.
        buffers : list of binary buffers
            Binary buffers to send with message
        """
        self._send({"method": "custom", "content": content}, buffers=buffers)

    def on_msg(self, callback, remove=False):
        """(Un)Register a custom msg receive callback.

        Parameters
        ----------
        callback: callable
            callback will be passed three arguments when a message arrives::

                callback(widget, content, buffers)

        remove: bool
            True if the callback should be unregistered."""
        self._msg_callbacks.register_callback(callback, remove=remove)

    def on_displayed(self, callback, remove=False):
        """(Un)Register a widget displayed callback.

        Parameters
        ----------
        callback: method handler
            Must have a signature of::

                callback(widget, **kwargs)

            kwargs from display are passed through without modification.
        remove: bool
            True if the callback should be unregistered."""
        self._display_callbacks.register_callback(callback, remove=remove)

    def add_traits(self, **traits):
        """Dynamically add trait attributes to the Widget."""
        super(Widget, self).add_traits(**traits)
        for name, trait in traits.items():
            if trait.get_metadata('sync'):
                self.keys.append(name)
                self.send_state(name)

    def notify_change(self, change):
        """Called when a property has changed."""
        # Send the state to the frontend before the user-registered callbacks
        # are called.
        name = change['name']
        if self.comm is not None and self.comm.kernel is not None:
            # Make sure this isn't information that the front-end just sent us.
            if name in self.keys and self._should_send_property(name, change['new']):
                # Send new state to front-end
                self.send_state(key=name)
        LoggingConfigurable.notify_change(self, change)

    #-------------------------------------------------------------------------
    # Support methods
    #-------------------------------------------------------------------------
    @contextmanager
    def _lock_property(self, **properties):
        """Lock a property-value pair.

        The value should be the JSON state of the property.

        NOTE: This, in addition to the single lock for all state changes, is
        flawed. In the future we may want to look into buffering state changes
        back to the front-end."""
        self._property_lock = properties
        try:
            yield
        finally:
            self._property_lock = {}

    @contextmanager
    def hold_sync(self):
        """Hold syncing any state until the outermost context manager exits"""
        if self._holding_sync is True:
            # Already inside an outer hold_sync; let it do the flushing.
            yield
        else:
            try:
                self._holding_sync = True
                yield
            finally:
                self._holding_sync = False
                self.send_state(self._states_to_send)
                self._states_to_send.clear()

    def _should_send_property(self, key, value):
        """Check the property lock (property_lock)"""
        to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
        if (key in self._property_lock
                and to_json(value, self) == self._property_lock[key]):
            return False
        elif self._holding_sync:
            self._states_to_send.add(key)
            return False
        else:
            return True

    # Event handlers
    @_show_traceback
    def _handle_msg(self, msg):
        """Called when a msg is received from the front-end"""
        data = msg['content']['data']
        method = data['method']

        # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
        if method == 'backbone':
            if 'sync_data' in data:
                # get binary buffers too
                sync_data = data['sync_data']
                for i, k in enumerate(data.get('buffer_keys', [])):
                    sync_data[k] = msg['buffers'][i]
                self.set_state(sync_data)  # handles all methods

        # Handle a state request.
        elif method == 'request_state':
            self.send_state()

        # Handle a custom msg from the front-end.
        elif method == 'custom':
            if 'content' in data:
                self._handle_custom_msg(data['content'], msg['buffers'])

        # Catch remainder.
        else:
            self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)

    def _handle_custom_msg(self, content, buffers):
        """Called when a custom msg is received."""
        self._msg_callbacks(self, content, buffers)

    def _handle_displayed(self, **kwargs):
        """Called when a view has been displayed for this widget instance"""
        self._display_callbacks(self, **kwargs)

    @staticmethod
    def _trait_to_json(x, self):
        """Convert a trait value to json."""
        return x

    @staticmethod
    def _trait_from_json(x, self):
        """Convert json values to objects."""
        return x

    def _ipython_display_(self, **kwargs):
        """Called when `IPython.display.display` is called on the widget."""
        def loud_error(message):
            # Log the warning and mirror it to stderr so the user sees it.
            # NOTE(review): log.warn is deprecated in the logging module in
            # favor of log.warning; left as-is to preserve behavior.
            self.log.warn(message)
            sys.stderr.write('%s\n' % message)

        # Show view.
        if self._view_name is not None:
            validated = Widget._version_validated
            # Before the user tries to display a widget, validate that the
            # widget front-end is what is expected.
            if validated is None:
                loud_error('Widget Javascript not detected. It may not be '
                           'installed or enabled properly.')
            elif not validated:
                msg = ('The installed widget Javascript is the wrong version.'
                       ' It must satisfy the semver range %s.' % __frontend_version__)
                if (Widget._version_frontend):
                    msg += ' The widget Javascript is version %s.' % Widget._version_frontend
                loud_error(msg)

            # TODO: delete this sending of a comm message when the display statement
            # below works. Then add a 'text/plain' mimetype to the dictionary below.
            self._send({"method": "display"})

            # The 'application/vnd.jupyter.widget-view+json' mimetype has not been registered yet.
            # See the registration process and naming convention at
            # http://tools.ietf.org/html/rfc6838
            # and the currently registered mimetypes at
            # http://www.iana.org/assignments/media-types/media-types.xhtml.
            # We don't have a 'text/plain' entry, so this display message will be
            # will be invisible in the current notebook.
            data = {
                'application/vnd.jupyter.widget-view+json': {
                    'model_id': self._model_id
                }
            }
            display(data, raw=True)

            self._handle_displayed(**kwargs)

    def _send(self, msg, buffers=None):
        """Sends a message to the model in the front-end."""
        if self.comm is not None and self.comm.kernel is not None:
            self.comm.send(data=msg, buffers=buffers)
class _MultipleSelection(DescriptionWidget, ValueWidget, CoreWidget):
    """Base class for multiple Selection widgets

    ``options`` can be specified as a list of values, list of (label, value)
    tuples, or a dict of {label: value}. The labels are the strings that will
    be displayed in the UI, representing the actual Python choices, and should
    be unique. If labels are not specified, they are generated from the values.

    When programmatically setting the value, a reverse lookup is performed
    among the options to check that the value is valid. The reverse lookup uses
    the equality operator by default, but another predicate may be provided via
    the ``equals`` keyword argument. For example, when dealing with numpy
    arrays, one may set equals=np.array_equal.
    """

    # The three selection representations are kept mutually consistent by the
    # observers below; only ``index`` is synced with the front-end.
    value = TypedTuple(trait=Any(), help="Selected values")
    label = TypedTuple(trait=Unicode(), help="Selected labels")
    index = TypedTuple(trait=Int(), help="Selected indices").tag(sync=True)

    options = Any((),
        help="""Iterable of values or (label, value) pairs that the user can select.

    The labels are the strings that will be displayed in the UI, representing the
    actual Python choices, and should be unique.
    """)

    # Normalized ((label, value), ...) form of ``options``; rebuilt by the
    # options validator and read by _propagate_options.
    _options_full = None

    # This being read-only means that it cannot be changed from the frontend!
    _options_labels = TypedTuple(trait=Unicode(), read_only=True, help="The labels for the options.").tag(sync=True)

    disabled = Bool(help="Enable or disable user changes").tag(sync=True)

    def __init__(self, *args, **kwargs):
        # Predicate used in the reverse value lookup; defaults to ==.
        self.equals = kwargs.pop('equals', lambda x, y: x == y)

        # We have to make the basic options bookkeeping consistent
        # so we don't have errors the first time validators run
        self._initializing_traits_ = True
        kwargs['options'] = _exhaust_iterable(kwargs.get('options', ()))
        self._options_full = _make_options(kwargs['options'])
        self._propagate_options(None)

        super().__init__(*args, **kwargs)
        self._initializing_traits_ = False

    @validate('options')
    def _validate_options(self, proposal):
        # Materialize generators so options can be iterated more than once.
        proposal.value = _exhaust_iterable(proposal.value)
        # throws an error if there is a problem converting to full form
        self._options_full = _make_options(proposal.value)
        return proposal.value

    @observe('options')
    def _propagate_options(self, change):
        "Unselect any option"
        options = self._options_full
        # set_trait bypasses the read_only guard on _options_labels.
        self.set_trait('_options_labels', tuple(i[0] for i in options))
        self._options_values = tuple(i[1] for i in options)
        if self._initializing_traits_ is not True:
            self.index = ()

    @validate('index')
    def _validate_index(self, proposal):
        "Check the range of each proposed index."
        if all(0 <= i < len(self._options_labels) for i in proposal.value):
            return proposal.value
        else:
            raise TraitError('Invalid selection: index out of bounds')

    @observe('index')
    def _propagate_index(self, change):
        "Propagate changes in index to the value and label properties"
        label = tuple(self._options_labels[i] for i in change.new)
        value = tuple(self._options_values[i] for i in change.new)
        # we check equality so we can avoid validation if possible
        if self.label != label:
            self.label = label
        if self.value != value:
            self.value = value

    @validate('value')
    def _validate_value(self, proposal):
        "Replace all values with the actual objects in the options list"
        try:
            return tuple(findvalue(self._options_values, i, self.equals) for i in proposal.value)
        except ValueError:
            raise TraitError('Invalid selection: value not found')

    @observe('value')
    def _propagate_value(self, change):
        # Keep index in step with a programmatic value change.
        index = tuple(self._options_values.index(i) for i in change.new)
        if self.index != index:
            self.index = index

    @validate('label')
    def _validate_label(self, proposal):
        if any(i not in self._options_labels for i in proposal.value):
            raise TraitError('Invalid selection: label not found')
        return proposal.value

    @observe('label')
    def _propagate_label(self, change):
        # Keep index in step with a programmatic label change.
        index = tuple(self._options_labels.index(i) for i in change.new)
        if self.index != index:
            self.index = index

    def _repr_keys(self):
        keys = super()._repr_keys()
        # Include options manually, as it isn't marked as synced:
        yield from sorted(chain(keys, ('options', )))
class Dazer(Configurable):
    # Configurable hyper-parameters (help strings are user-facing config text
    # and are kept verbatim).
    emb_file = Unicode('None', help='词向量文件路径').tag(config=True)
    vocabulary_size = Int(400000, help='加载的词汇数量').tag(config=True)
    embedding_size = Int(50, help='词潜入的纬度').tag(config=True)
    train_file = Unicode('None', help='训练样本文件').tag(config=True)
    train_labels = Unicode('None', help='需要训练的分类标签').tag(config=True)
    all_labels = Unicode('None', help='全分类标签').tag(config=True)
    test_file = Unicode('None', help='训练样本文件').tag(config=True)
    ckpt_path = Unicode('None', help='模型保存路径').tag(config=True)
    test_result_file = Unicode('None', help='测试数据结果').tag(config=True)
    summary_path = Unicode('None', help='统计标量保存路径').tag(config=True)
    kernal_num = Int(50, help="滤波器个数").tag(config=True)
    kernal_width = Int(5, help="滤波器宽度").tag(config=True)
    max_epochs = Int(10, help="最大训练轮数").tag(config=True)
    eval_frequency = Int(1, help="将数据写入可视化的频率").tag(config=True)
    batch_size = Int(16, help="批次大小").tag(config=True)
    load_model = Bool(False, help="是否加载已经训练的模型").tag(config=True)
    max_pooling_num = Int(3, help="最池化k值").tag(config=True)
    decoder_mlp1_num = Int(75, help="解码第一层全连接的神经元个数").tag(config=True)
    regular_term = Float(0.01, help='正则项系数,对抗分类和相似性回归 共享这个超参数').tag(config=True)
    adv_learning_rate = Float(0.001, help='对抗分类的学习率').tag(config=True)
    epsilon = Float(0.00001, help='对抗分类的学习率').tag(config=True)
    model_learning_rate = Float(0.001, help='相似性模型的学习率').tag(config=True)
    adv_loss_term = Float(0.2, help='adv损失在总损失中占的比例').tag(config=True)
    # NOTE(review): the bare string below ("zero-shot filtering model") was
    # presumably intended as the class docstring, but it is not in docstring
    # position, so it is a no-op statement. Left byte-identical on purpose.
    """零样本过滤模型"""

    def __init__(self, **kwargs):
        """Load word vectors, derive the class count and build the TF graph."""
        super(Dazer, self).__init__(**kwargs)
        # Instance attributes: vocabulary maps and the embedding matrix
        # (we.load_word2vec is a project helper — behavior assumed, confirm).
        self.word2id, self.id2word, self.emb = we.load_word2vec(
            self.emb_file, self.vocabulary_size, self.embedding_size)
        self.class_num = len(self.train_labels.split(','))
        # Build the dataflow graph in a dedicated tf.Graph.
        self.g = tf.Graph()
        self.structure_init()

    def structure_init(self):
        """Define the TF1 dataflow graph: inputs, gated CNN encoder, hinge
        loss, adversarial classifier and the two training ops."""
        with self.g.as_default():
            # Input placeholders: query ids, query lengths, positive/negative
            # document ids and the one-hot class indicator.
            self.input_q = tf.placeholder(dtype=tf.int32, shape=[None, None])
            self.input_q_len = tf.placeholder(dtype=tf.float32, shape=[
                None,
            ])
            self.input_pos_d = tf.placeholder(dtype=tf.int32,
                                              shape=[None, None])
            self.input_neg_d = tf.placeholder(dtype=tf.int32,
                                              shape=[None, None])
            self.input_l = tf.placeholder(dtype=tf.int32,
                                          shape=[None, self.class_num])
            # Word embeddings lookup.
            emb_q = tf.nn.embedding_lookup(self.emb, self.input_q)
            emb_pos_d = tf.nn.embedding_lookup(self.emb, self.input_pos_d)
            emb_neg_d = tf.nn.embedding_lookup(self.emb, self.input_neg_d)
            # Class (label) vector: mean of seed-word embeddings, then a
            # sigmoid gate over convolution kernels.
            class_vec = tf.divide(tf.reduce_sum(emb_q, axis=1),
                                  tf.expand_dims(self.input_q_len, axis=-1))
            query_gate_weights = weight_init(
                [self.embedding_size, self.kernal_num], 'gate_weights')
            query_gate_bias = tf.Variable(initial_value=tf.zeros(
                self.kernal_num, ), name='gate_bias')
            # shape: [batch_size, kernal_num]
            gate_vec = tf.sigmoid(
                tf.matmul(class_vec, query_gate_weights) + query_gate_bias)
            rs_gate_vec = tf.expand_dims(gate_vec, axis=1)
            # Document interaction features: [emb, class - emb, class * emb].
            pos_sub_info = tf.subtract(tf.expand_dims(class_vec, axis=1),
                                       emb_pos_d)
            pos_mul_info = tf.multiply(emb_pos_d,
                                       tf.expand_dims(class_vec, axis=1))
            conv_pos_input = tf.expand_dims(tf.concat(
                [emb_pos_d, pos_sub_info, pos_mul_info], -1),
                                            axis=-1)
            neg_sub_info = tf.subtract(tf.expand_dims(class_vec, axis=1),
                                       emb_neg_d)
            neg_mul_info = tf.multiply(emb_neg_d,
                                       tf.expand_dims(class_vec, axis=1))
            conv_neg_input = tf.expand_dims(tf.concat(
                [emb_neg_d, neg_sub_info, neg_mul_info], -1),
                                            axis=-1)
            # Convolution extracts windowed document features; the neg branch
            # reuses the pos branch's weights via reuse=True on 'doc_conv'.
            pos_conv = tf.layers.conv2d(inputs=conv_pos_input,
                                        filters=self.kernal_num,
                                        kernel_size=(self.kernal_width,
                                                     self.embedding_size * 3),
                                        strides=(1, self.embedding_size * 3),
                                        padding="same",
                                        name='doc_conv',
                                        trainable=True)
            neg_conv = tf.layers.conv2d(inputs=conv_neg_input,
                                        filters=self.kernal_num,
                                        kernel_size=(self.kernal_width,
                                                     self.embedding_size * 3),
                                        strides=(1, self.embedding_size * 3),
                                        padding="same",
                                        name='doc_conv',
                                        trainable=True,
                                        reuse=True)
            # shape=[batch,max_dlen,1,kernal_num]
            # reshape to [batch,max_dlen,kernal_num]
            rs_pos_conv = tf.squeeze(pos_conv, [2])
            rs_neg_conv = tf.squeeze(neg_conv, [2])
            # Gate the convolution output per kernel.
            pos_gate_conv = tf.multiply(rs_gate_vec, rs_pos_conv)
            neg_gate_conv = tf.multiply(rs_gate_vec, rs_neg_conv)
            # k-max pooling via top_k (relative position is discarded).
            transpose_pos_gate_conv = tf.transpose(
                pos_gate_conv, [0, 2, 1])
            # reshape to [batch,kernal_num,max_dlen]
            pos_conv_k_max_pooling, _ = tf.nn.top_k(transpose_pos_gate_conv,
                                                    self.max_pooling_num)
            pos_encoder = tf.reshape(
                pos_conv_k_max_pooling,
                [-1, self.kernal_num * self.max_pooling_num])
            transpose_neg_gate_conv = tf.transpose(neg_gate_conv, [0, 2, 1])
            neg_conv_k_max_pooling, _ = tf.nn.top_k(transpose_neg_gate_conv,
                                                    self.max_pooling_num)
            neg_encoder = tf.reshape(
                neg_conv_k_max_pooling,
                [-1, self.kernal_num * self.max_pooling_num])
            # Relevance aggregation MLPs; weights shared across branches.
            pos_decoder_mlp1 = tf.layers.dense(inputs=pos_encoder,
                                               units=self.decoder_mlp1_num,
                                               activation=tf.nn.tanh,
                                               trainable=True,
                                               name='decoder_mlp1')
            neg_encoder_mlp1 = tf.layers.dense(inputs=neg_encoder,
                                               units=self.decoder_mlp1_num,
                                               activation=tf.nn.tanh,
                                               trainable=True,
                                               name='decoder_mlp1',
                                               reuse=True)
            self.pos_score = tf.layers.dense(inputs=pos_decoder_mlp1,
                                             units=1,
                                             activation=tf.nn.tanh,
                                             trainable=True,
                                             name='decoder_mlp2')
            neg_score = tf.layers.dense(inputs=neg_encoder_mlp1,
                                        units=1,
                                        activation=tf.nn.tanh,
                                        trainable=True,
                                        name='decoder_mlp2',
                                        reuse=True)
            # Pairwise hinge loss with margin 1.
            hinge_loss = tf.reduce_mean(
                tf.maximum(0.0, 1 - self.pos_score + neg_score))
            tf.summary.scalar('hinge_loss', hinge_loss)
            # Adversarial classifier branch (tries to recover the class from
            # the positive-document representation).
            adv_weight = weight_init([self.decoder_mlp1_num, self.class_num],
                                     'adv_weights')
            adv_bias = tf.Variable(initial_value=tf.zeros(self.class_num, ),
                                   name='adv_bias')
            adv_prob = tf.nn.softmax(
                tf.add(tf.matmul(pos_decoder_mlp1, adv_weight), adv_bias))
            # NOTE(review): tf.log(softmax) is numerically unstable if a
            # probability underflows to 0 — confirm acceptable for this data.
            adv_prob_log = tf.log(adv_prob)
            adv_loss = tf.reduce_mean(
                tf.reduce_sum(tf.multiply(adv_prob_log,
                                          tf.cast(self.input_l, tf.float32)),
                              axis=1))
            # L2 over adversarial (non-bias) variables; 'b' substring filter
            # excludes bias variables by naming convention.
            adv_l2_loss = self.regular_term * l2_loss([
                v for v in tf.trainable_variables()
                if 'b' not in v.name and 'adv' in v.name
            ])
            loss_cat = -1 * adv_loss + adv_l2_loss
            # Dynamic learning rate variable (assigned externally if needed).
            self.lr = tf.Variable(self.model_learning_rate,
                                  trainable=False,
                                  name='learning_rate')
            # Two optimizers over disjoint variable sets (GRL-style setup).
            self.adv_train_op = tf.train.AdamOptimizer(learning_rate=self.lr, epsilon=self.epsilon)\
                .minimize(loss_cat, var_list=[v for v in tf.trainable_variables() if 'adv' in v.name])
            tf.summary.scalar('loss_cat', loss_cat)
            model_l2_loss = self.regular_term * l2_loss([
                v for v in tf.trainable_variables()
                if 'b' not in v.name and 'adv' not in v.name
            ])
            loss = hinge_loss + model_l2_loss + adv_loss * self.adv_loss_term
            self.model_train_op = tf.train.AdamOptimizer(learning_rate=self.lr, epsilon=self.epsilon)\
                .minimize(loss, var_list=[v for v in tf.trainable_variables() if 'adv' not in v.name])
            tf.summary.scalar('loss', loss)
            self.merged = tf.summary.merge_all()

    def analysis_result(self, sess, test_df):
        """Score every document against every label's query and build a
        prediction report.

        Returns (result_df, classification_report_str, weighted_f1,
        labels, confusion_matrix).
        """
        # dd1_list is a project helper — presumably a defaultdict(list)
        # factory; TODO confirm.
        result = dd1_list()
        for batch in dp.input_scale_data_test(test_df,
                                              self.max_pooling_num,
                                              batch_size=self.batch_size):
            query_dict = batch['query_dict']
            query_len_dict = batch['query_len_dict']
            doc = batch['doc']
            l_context = batch['l_context']
            l_label = batch['l_label']
            # Accumulate ground truth and contexts for the result table.
            result['real_label'] += l_label
            result['context'] += l_context
            for label, query in query_dict.items():
                score = sess.run(self.pos_score,
                                 feed_dict={
                                     self.input_q: query,
                                     self.input_pos_d: doc,
                                     self.input_q_len: query_len_dict[label]
                                 })
                result[label] += list(np.reshape(score, (-1, )))
        result_df = pd.DataFrame(result)
        # Labels taken from the last batch's query_dict.
        # NOTE(review): query_dict leaks out of the loop; an empty test set
        # would raise NameError here — confirm inputs are never empty.
        labels = [label for label, query in query_dict.items()]
        # Hard-coded label order alternative; other datasets would need a
        # different order:
        # labels = ['very negative', 'negative', 'neutral', 'positive', 'very positive']
        # Predicted label = the query with the highest relevance score.
        result_df['pre_label'] = result_df[labels].idxmax(axis=1)
        f1 = metrics.f1_score(result_df['real_label'],
                              result_df['pre_label'],
                              average='weighted')
        result_report = classification_report(result_df['real_label'],
                                              result_df['pre_label'])
        con_matrix = confusion_matrix(result_df['real_label'],
                                      result_df['pre_label'],
                                      labels=labels)
        return result_df, result_report, f1, labels, con_matrix

    def test(self):
        """Restore a checkpoint and run prediction over the test file,
        writing the per-document scores to test_result_file."""
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config, graph=self.g) as sess:
            train_vars = [v for v in tf.trainable_variables()]
            saver = tf.train.Saver(var_list=train_vars)
            # TODO: checkpoint step '-77212' is hard-coded — consider
            # restoring the latest checkpoint automatically instead.
            saver.restore(sess, self.ckpt_path + '-77212')
            test_df = pd.read_csv(self.test_file, engine='python')
            # Debug mode (TF CLI debugger):
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            result_df, result_report, f1, labels, con_matrix = self.analysis_result(
                sess, test_df)
            logger.info('生成测试集预测统计,并保存具体预测结果:')
            logger.info('\n' + result_report)
            result_df.to_csv(self.test_result_file)

    # NOTE(review): a large commented-out single-run `train()` variant
    # (superseded by train_all below) was removed here as dead code.

    def train_all(self):
        """Loop over all labels so that each in turn is the zero-shot label,
        train a fresh model per label and track the best test F1."""
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        now_time = datetime.now().strftime('%Y%m%d%H%M')
        train_df = pd.read_csv(self.train_file, engine='python')
        test_df = pd.read_csv(self.test_file, engine='python')
        logger.info(
            '------------------------------------------------------------------------------------------------'
        )
        logger.info('启动全样本训练, 时间:{}'.format(now_time))
        for zero_label, train_labels, train_data, test_data in reorganize_data(
                self.all_labels, train_df, test_df):
            # Hyper-parameter tuning for one specific zero-shot label only;
            # every other label is skipped.
            if zero_label != 'very positive':
                continue
            else:
                # self.regular_term = 0.0005  # regular term for 'negative'
                # self.regular_term = 0.0001  # regular term for 'very negative'
                self.regular_term = 0.001  # regular term for positive / very positive
                self.adv_learning_rate = 0.01  # learning rate for positive / very positive
            with tf.Session(config=config, graph=self.g) as sess:
                train_vars = [v for v in tf.trainable_variables()]
                saver = tf.train.Saver(var_list=train_vars, max_to_keep=3)
                sess.run(tf.global_variables_initializer())
                max_f1 = 0.0
                best_test_result_report = ''
                best_con_matrix_report = ''
                step = 0
                # Summaries are written to a per-label path.
                writer = tf.summary.FileWriter(
                    self.summary_path.format(now_time,
                                             zero_label.replace(' ', '')),
                    self.g)
                logger.info('开启学习---零样本标签:{}'.format(zero_label))
                logger.info(
                    '本次训练embedding_size:{}. max_epochs:{}, batch_size: {}, regular_term:{}, adv_learning_rate:'
                    '{}, epsilon:{}, adv_loss_term:{}'.format(
                        self.embedding_size, self.max_epochs, self.batch_size,
                        self.regular_term, self.adv_learning_rate,
                        self.epsilon, self.adv_loss_term))
                for epoch in range(int(self.max_epochs)):
                    # Dynamic learning-rate schedules (disabled):
                    # if epoch + 1 % 25 == 0:
                    #     sess.run(self.lr.assign(self.model_learning_rate / 5.0))  # step decay
                    # sess.run(self.lr.assign(self.model_learning_rate * 0.95**epoch))  # exponential decay
                    for batch in dp.input_scale_data_train(
                            train_data,
                            train_labels,
                            batch_size=self.batch_size,
                            min_len=self.max_pooling_num):
                        query = batch['query']
                        label = batch['label']
                        pos_doc = batch['pos_doc']
                        neg_doc = batch['neg_doc']
                        query_len = batch['query_len']
                        # One step of both the adversarial and the model
                        # optimizers, plus the merged summaries.
                        merged, ato, mto = sess.run(
                            [
                                self.merged, self.adv_train_op,
                                self.model_train_op
                            ],
                            feed_dict={
                                self.input_q: query,
                                self.input_l: label,
                                self.input_pos_d: pos_doc,
                                self.input_neg_d: neg_doc,
                                self.input_q_len: query_len
                            })
                        # TensorBoard visualization.
                        writer.add_summary(merged, global_step=step)
                        writer.flush()
                        step += 1
                    # Validate and log every 5 epochs.
                    if epoch % 5 == 0:
                        logger.info('第{}次训练结果:'.format(epoch + 1))
                        logger.info('训练样本的准确率:')
                        train_result_df, train_result_report, train_f1, train_labels, train_con_matrix = self.analysis_result(
                            sess, train_data)
                        logger.info('\n' + train_result_report)
                        logger.info('测试样本和零样本的准确率:')
                        test_result_df, test_result_report, test_f1, test_labels, test_con_matrix = self.analysis_result(
                            sess, test_data)
                        logger.info('\n' + test_result_report)
                        logger.info('测试样本和零样本的混淆矩阵:{}'.format(test_labels))
                        test_con_matrix_report = con_matrix_2_str(
                            test_labels, test_con_matrix)
                        logger.info('\n' + test_con_matrix_report)
                        if test_f1 > max_f1:
                            # Save the weights for the best F1 so far.
                            saver.save(sess,
                                       self.ckpt_path.format(
                                           now_time,
                                           zero_label.replace(' ', '')),
                                       global_step=step)
                            # Record the best report and confusion matrix.
                            best_test_result_report = test_result_report
                            best_con_matrix_report = test_con_matrix_report
                            # Save the predictions.
                            test_result_df.to_csv(
                                self.test_result_file.format(
                                    now_time, zero_label.replace(' ', '')))
                            max_f1 = test_f1
                    # After the last epoch, log the best results seen.
                    if epoch + 1 == int(self.max_epochs):
                        logger.info('测试集上最好的结果:')
                        logger.info('\n' + best_test_result_report)
                        logger.info('\n' + best_con_matrix_report)
                    logger.info('第{}次epoch训练完成。'.format(epoch + 1))
                writer.close()
class DisplayDL1Calib(Tool):
    """ctapipe Tool that calibrates events to DL1 and plots the camera images.

    Reads events from an EventSource, runs the CameraCalibrator on each, and
    hands the requested telescope image(s) to an ImagePlotter.
    """

    name = "ctapipe-display-dl1"
    description = __doc__

    # Telescope id to display; None (the default) means all telescopes.
    telescope = Int(
        None,
        allow_none=True,
        help="Telescope to view. Set to None to display all telescopes.",
    ).tag(config=True)

    extractor_product = traits.enum_trait(ImageExtractor,
                                          default="NeighborPeakWindowSum")

    # Command-line shortcuts mapped onto configurable traits.
    aliases = Dict(
        dict(
            input="EventSource.input_url",
            max_events="EventSource.max_events",
            extractor="DisplayDL1Calib.extractor_product",
            T="DisplayDL1Calib.telescope",
            O="ImagePlotter.output_path",
        ))
    flags = Dict(
        dict(D=(
            {
                "ImagePlotter": {
                    "display": True
                }
            },
            "Display the photo-electron images on-screen as they are produced.",
        )))
    classes = List([EventSource, ImagePlotter] +
                   traits.classes_with_traits(ImageExtractor))

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Components are created in setup(); placeholders until then.
        self.eventsource = None
        self.calibrator = None
        self.plotter = None

    def setup(self):
        """Create the event source, calibrator and plotter components."""
        self.eventsource = self.add_component(
            EventSource.from_url(
                get_dataset_path("gamma_test_large.simtel.gz"), parent=self))
        self.calibrator = self.add_component(CameraCalibrator(parent=self))
        self.plotter = self.add_component(ImagePlotter(parent=self))

    def start(self):
        """Calibrate each event and plot the selected telescope image(s)."""
        for event in self.eventsource:
            self.calibrator(event)

            tel_list = event.r0.tels_with_data
            # BUGFIX: compare explicitly against None — the trait documents
            # None as "all telescopes", but the old truthiness test also
            # treated telescope id 0 as "all".
            if self.telescope is not None:
                if self.telescope not in tel_list:
                    continue
                tel_list = [self.telescope]
            for telid in tel_list:
                self.plotter.plot(event, telid)

    def finish(self):
        self.plotter.finish()
class DAZER(BaseNN):
    """Zero-shot document filtering model (DAZER) built on TF1.

    The graph is constructed inside train()/test(); shared weights are
    created once in __init__ and reused across positive/negative branches
    via variable-scope reuse.
    """

    #params of zeroshot document filtering model
    embedding_size = Int(300, help="embedding dimension").tag(config=True)
    vocabulary_size = Int(2000000, help="vocabulary size").tag(config=True)
    kernal_width = Int(5, help='kernal width').tag(config=True)
    kernal_num = Int(50, help='number of kernal').tag(config=True)
    regular_term = Float(0.01, help='param for controlling wight of L2 loss').tag(config=True)
    maxpooling_num = Int(3, help='number of k-maxpooling').tag(config=True)
    decoder_mlp1_num = Int(75, help='number of hidden units of first mlp in relevance aggregation part').tag(config=True)
    decoder_mlp2_num = Int(1, help='number of hidden units of second mlp in relevance aggregation part').tag(config=True)
    emb_in = Unicode('None', help="initial embedding. Terms should be hashed to ids.").tag(config=True)
    model_learning_rate = Float(0.001, help="learning rate of model").tag(config=True)
    adv_learning_rate = Float(0.001, help='learning rate of adv classifier').tag(config=True)
    epsilon = Float(0.00001, help="Epsilon for Adam").tag(config=True)
    label_dict_path = Unicode('None', help='label dict path').tag(config=True)
    word2id_path = Unicode('None', help='word2id path').tag(config=True)
    train_class_num = Int(16, help='num of class in training data').tag(config=True)
    adv_term = Float(0.2, help='regular term of adversrial loss').tag(config=True)
    zsl_num = Int(1, help='num of zeroshot label').tag(config=True)
    zsl_type = Int(1, help='type of zeroshot label setting').tag(config=True)

    def __init__(self, **kwargs):
        #init the DAZER model
        super(DAZER, self).__init__(**kwargs)
        print ("trying to load initial embeddings from: ", self.emb_in)
        if self.emb_in != 'None':
            # Pre-trained embeddings, frozen (trainable=False).
            self.emb = self.load_word2vec(self.emb_in)
            self.embeddings = tf.Variable(tf.constant(self.emb, dtype='float32', shape=[self.vocabulary_size + 1, self.embedding_size]),trainable=False)
            print ("Initialized embeddings with {0}".format(self.emb_in))
        else:
            # Randomly-initialized (trainable by default) embeddings.
            self.embeddings = tf.Variable(tf.random_uniform([self.vocabulary_size + 1, self.embedding_size], -1.0, 1.0))
        #variables of the DAZER model
        self.query_gate_weight = BaseNN.weight_variable((self.embedding_size, self.kernal_num),'gate_weight')
        self.query_gate_bias = tf.Variable(initial_value=tf.zeros((self.kernal_num)),name='gate_bias')
        self.adv_weight = BaseNN.weight_variable((self.decoder_mlp1_num,self.train_class_num),name='adv_weight')
        self.adv_bias = tf.Variable(initial_value=tf.zeros((1,self.train_class_num)),name='adv_bias')
        #get the label information to help adversarial learning
        self.label_dict, self.reverse_label_dict, self.label_list = get_label.get_labels(self.label_dict_path, self.word2id_path)
        self.label_index_dict = get_label.get_label_index(self.label_list, self.zsl_num, self.zsl_type)

    def load_word2vec(self, emb_file_path):
        """Load id-indexed word vectors into a dense (vocab+1, dim) matrix.

        The file's first line is a header and is skipped; each following line
        is "<term_id> <v1> ... <vdim>". Ids above vocabulary_size are ignored.
        """
        emb = np.zeros((self.vocabulary_size + 1, self.embedding_size))
        nlines = 0
        with open(emb_file_path) as f:
            for line in f:
                nlines += 1
                if nlines == 1:
                    # skip the header line
                    continue
                items = line.split()
                tid = int(items[0])
                if tid > self.vocabulary_size:
                    print (tid)
                    continue
                vec = np.array([float(t) for t in items[1:]])
                emb[tid, :] = vec
                if nlines % 20000 == 0:
                    print ("load {0} vectors...".format(nlines))
        return emb

    def gen_adv_query_mask(self, q_ids):
        """Build a (batch_size, train_class_num) one-hot class mask from the
        batch's query-string ids, for the adversarial loss."""
        q_mask = np.zeros((self.batch_size, self.train_class_num))
        for batch_num, b_q_id in enumerate(q_ids):
            c_name = self.reverse_label_dict[b_q_id]
            c_index = self.label_index_dict[c_name]
            q_mask[batch_num][c_index] = 1
        return q_mask

    def get_class_gate(self,class_vec, emb_d):
        '''
        compute the gate in kernal space
        :param class_vec: avg emb of seed words
        :param emb_d: emb of doc
        :return:the class gate [batchsize,d_len,kernal_num]
        '''
        gate1 = tf.expand_dims(tf.matmul(class_vec, self.query_gate_weight), axis=1)
        bias = tf.expand_dims(self.query_gate_bias,axis=0)
        gate = tf.add(gate1, bias)
        return tf.sigmoid(gate)

    def L2_model_loss(self):
        """Sum of L2 losses over non-bias, non-adversarial trainables.

        The 'b' substring filter excludes bias variables by naming convention.
        """
        all_para = [v for v in tf.trainable_variables() if 'b' not in v.name and 'adv' not in v.name]
        loss = 0.
        for each in all_para:
            loss += tf.nn.l2_loss(each)
        return loss

    def L2_adv_loss(self):
        """Sum of L2 losses over non-bias adversarial-classifier trainables."""
        all_para = [v for v in tf.trainable_variables() if 'b' not in v.name and 'adv' in v.name]
        loss = 0.
        for each in all_para:
            loss += tf.nn.l2_loss(each)
        return loss

    def train(self, train_pair_file_path, val_pair_file_path, checkpoint_dir, load_model=False):
        """Build the pairwise training graph and run the training loop.

        :param train_pair_file_path: path to pairwise training data
        :param val_pair_file_path: path to pairwise validation data
        :param checkpoint_dir: directory for checkpoints (saved/restored)
        :param load_model: resume from checkpoint_dir+'/zsl25.ckpt' if True
        """
        # Placeholders for a fixed-size batch of (query, pos doc, neg doc).
        input_q = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_q_len])
        input_pos_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
        input_neg_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
        q_lens = tf.placeholder(tf.float32, shape=[self.batch_size,])
        q_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_q_len])
        pos_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
        neg_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
        input_q_index = tf.placeholder(tf.int32, shape=[self.batch_size,self.train_class_num])
        emb_q = tf.nn.embedding_lookup(self.embeddings,input_q)
        # Masked mean of the seed-word embeddings.
        class_vec_sum = tf.reduce_sum(
            tf.multiply(emb_q,tf.expand_dims(q_mask,axis=-1)),
            axis=1
        )
        #get class vec
        class_vec = tf.div(class_vec_sum,tf.expand_dims(q_lens,-1))
        emb_pos_d = tf.nn.embedding_lookup(self.embeddings,input_pos_d)
        emb_neg_d = tf.nn.embedding_lookup(self.embeddings,input_neg_d)
        #get query gate
        pos_query_gate = self.get_class_gate(class_vec, emb_pos_d)
        neg_query_gate = self.get_class_gate(class_vec, emb_neg_d)
        # CNN for document
        # Interaction features: [emb, class*emb, class-emb].
        pos_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_pos_d)
        pos_sub_info = tf.expand_dims(class_vec,axis=1) - emb_pos_d
        pos_conv_input = tf.concat([emb_pos_d,pos_mult_info,pos_sub_info], axis=-1)
        neg_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_neg_d)
        neg_sub_info = tf.expand_dims(class_vec,axis=1) - emb_neg_d
        neg_conv_input = tf.concat([emb_neg_d,neg_mult_info,neg_sub_info], axis=-1)
        #in fact that's 1D conv, but we implement it by conv2d
        pos_conv = tf.layers.conv2d(
            inputs = tf.expand_dims(pos_conv_input,axis=-1),
            filters = self.kernal_num,
            kernel_size=[self.kernal_width,self.embedding_size*3],
            strides = [1,self.embedding_size*3],
            padding = 'SAME',
            trainable = True,
            name='doc_conv'
        )
        # Negative branch shares the convolution weights (reuse=True).
        neg_conv = tf.layers.conv2d(
            inputs = tf.expand_dims(neg_conv_input,axis=-1),
            filters = self.kernal_num,
            kernel_size=[self.kernal_width,self.embedding_size*3],
            strides = [1,self.embedding_size*3],
            padding = 'SAME',
            trainable = True,
            name='doc_conv',
            reuse=True
        )
        #shape=[batch,max_dlen,1,kernal_num]
        #reshape to [batch,max_dlen,kernal_num]
        # NOTE(review): tf.squeeze without an axis would also drop the batch
        # dimension if batch_size were 1 — confirm batch_size > 1 always.
        rs_pos_conv = tf.squeeze(pos_conv)
        rs_neg_conv = tf.squeeze(neg_conv)
        #query_gate elment-wise multiply rs_pos_conv
        pos_gate_conv = tf.multiply(pos_query_gate, rs_pos_conv)
        neg_gate_conv = tf.multiply(neg_query_gate, rs_neg_conv)
        #K-max_pooling
        #transpose to [batch,knum,dlen],then get max k in each kernal filter
        transpose_pos_gate_conv = tf.transpose(pos_gate_conv, perm=[0,2,1])
        transpose_neg_gate_conv = tf.transpose(neg_gate_conv, perm=[0,2,1])
        #shape = [batch,k_num,maxpolling_num]
        #the k-max pooling here is implemented by function top_k, so the relative position information is ignored
        pos_kmaxpooling,_ = tf.nn.top_k(
            input=transpose_pos_gate_conv,
            k=self.maxpooling_num,
        )
        neg_kmaxpooling,_ = tf.nn.top_k(
            input=transpose_neg_gate_conv,
            k=self.maxpooling_num,
        )
        pos_encoder = tf.reshape(pos_kmaxpooling, shape=(self.batch_size,-1))
        neg_encoder = tf.reshape(neg_kmaxpooling, shape=(self.batch_size,-1))
        # Relevance aggregation MLPs; weights shared across branches.
        pos_decoder_mlp1 = tf.layers.dense(
            inputs=pos_encoder,
            units=self.decoder_mlp1_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp1'
        )
        neg_decoder_mlp1 = tf.layers.dense(
            inputs=neg_encoder,
            units=self.decoder_mlp1_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp1',
            reuse=True
        )
        pos_decoder_mlp2 = tf.layers.dense(
            inputs=pos_decoder_mlp1,
            units=self.decoder_mlp2_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp2'
        )
        neg_decoder_mlp2 = tf.layers.dense(
            inputs=neg_decoder_mlp1,
            units=self.decoder_mlp2_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp2',
            reuse=True
        )
        score_pos = pos_decoder_mlp2
        score_neg = neg_decoder_mlp2
        # Pairwise hinge loss with margin 1.
        hinge_loss = tf.reduce_mean(tf.maximum(0.0, 1 - score_pos + score_neg))
        # Adversarial classifier over the positive-document representation.
        adv_prob = tf.nn.softmax(tf.add(tf.matmul(pos_decoder_mlp1, self.adv_weight), self.adv_bias))
        # NOTE(review): tf.log(softmax) can underflow to -inf — confirm
        # acceptable for this data.
        log_adv_prob = tf.log(adv_prob)
        adv_loss = tf.reduce_mean(tf.reduce_sum(tf.multiply(log_adv_prob, tf.cast(input_q_index,tf.float32)), axis=1, keep_dims=True))
        L2_adv_loss = self.regular_term*self.L2_adv_loss()
        #to apply GRL, we use two seperate optimizers for adversarial classifier and the rest part of DAZER
        #optimizer for adversarial classifier
        adv_var_list = [v for v in tf.trainable_variables() if 'adv' in v.name]
        adv_opt = tf.train.AdamOptimizer(learning_rate=self.adv_learning_rate, epsilon=self.epsilon).minimize(loss=(-1 * adv_loss + L2_adv_loss), var_list=adv_var_list)
        #optimizer for rest part of DAZER model
        L2_model_loss = self.regular_term*self.L2_model_loss()
        model_var_list = [v for v in tf.trainable_variables() if 'adv' not in v.name]
        loss = hinge_loss + L2_model_loss + (adv_loss * self.adv_term)
        model_opt = tf.train.AdamOptimizer(learning_rate=self.model_learning_rate, epsilon=self.epsilon).minimize(loss = loss, var_list = model_var_list)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        val_results = []
        save_num = 0
        save_var = [v for v in tf.trainable_variables()]
        # Create a local session to run the training.
        with tf.Session(config=config) as sess:
            saver = tf.train.Saver(max_to_keep=50,var_list=save_var)
            start_time = time.time()
            if not load_model:
                print ("Initializing a new model...")
                init = tf.global_variables_initializer()
                sess.run(init)
                print('New model initialized!')
            else:
                #to load trained model, and keep training
                #remember to change the name of ckpt file
                init = tf.global_variables_initializer()
                sess.run(init)
                saver.restore(sess, checkpoint_dir+'/zsl25.ckpt')
                print ("model loaded!")
            # Loop through training steps.
            step = 0
            loss_list = []
            for epoch in range(int(self.max_epochs)):
                epoch_val_loss = 0
                epoch_loss = 0
                epoch_hinge_loss = 0.
                epoch_adv_loss = 0
                epoch_s = time.time()
                pair_stream = open(train_pair_file_path)
                for BATCH in self.data_generator.pairwise_reader(pair_stream, self.batch_size):
                    step += 1
                    X, Y = BATCH
                    query = X[u'q']
                    str_query = X[u'q_str']
                    q_index = self.gen_adv_query_mask(str_query)
                    pos_doc = X[u'd']
                    neg_doc = X[u'd_aux']
                    train_q_lens = X[u'q_lens']
                    M_query = self.gen_query_mask(query)
                    M_pos = self.gen_doc_mask(pos_doc)
                    M_neg = self.gen_doc_mask(neg_doc)
                    # Skip incomplete trailing batches (graph has a fixed
                    # batch dimension).
                    if X[u'q_lens'].shape[0] != self.batch_size:
                        continue
                    train_feed_dict = {input_q:query, input_pos_d:pos_doc, q_lens:train_q_lens, input_neg_d:neg_doc, q_mask:M_query, pos_d_mask:M_pos, neg_d_mask:M_neg, input_q_index: q_index}
                    # Run both optimizers in one step.
                    _1,l,hinge_l,_2,adv_l = sess.run([model_opt,loss,hinge_loss,adv_opt,adv_loss], feed_dict=train_feed_dict)
                    epoch_loss += l
                    epoch_hinge_loss += hinge_l
                    epoch_adv_loss += adv_l
                if (epoch + 1) % self.eval_frequency == 0:
                    #after eval_frequency epochs we run model on val dataset
                    val_start = time.time()
                    val_pair_stream = open(val_pair_file_path)
                    for BATCH in self.val_data_generator.pairwise_reader(val_pair_stream, self.batch_size):
                        X_val,Y_val = BATCH
                        query = X_val[u'q']
                        pos_doc = X_val[u'd']
                        neg_doc = X_val[u'd_aux']
                        val_q_lens = X_val[u'q_lens']
                        M_query = self.gen_query_mask(query)
                        M_pos = self.gen_doc_mask(pos_doc)
                        M_neg = self.gen_doc_mask(neg_doc)
                        if X_val[u'q'].shape[0] != self.batch_size:
                            continue
                        train_feed_dict = {input_q:query, input_pos_d:pos_doc, input_neg_d:neg_doc, q_lens:val_q_lens, q_mask:M_query, pos_d_mask:M_pos, neg_d_mask:M_neg}
                        # Run the graph and fetch some of the nodes.
                        v_loss = sess.run(hinge_loss, feed_dict=train_feed_dict)
                        epoch_val_loss += v_loss
                    val_results.append(epoch_val_loss)
                    val_end = time.time()
                    print('---Validation:epoch %d, %.1f ms , val_loss are %f' % (epoch+1,val_end-val_start,epoch_val_loss))
                    sys.stdout.flush()
                loss_list.append(epoch_loss)
                epoch_e = time.time()
                print('---Train:%d epoches cost %f seconds, hinge cost = %f model cost = %f, adv cost = %f...'%(epoch+1,epoch_e-epoch_s,epoch_hinge_loss, epoch_loss,epoch_adv_loss))
                # save model after checkpoint_steps epochs
                if (epoch+1)%self.checkpoint_steps == 0:
                    save_num += 1
                    # NOTE(review): checkpoint name 'zsl<epoch>.ckpt' must
                    # match the path built in test() below.
                    saver.save(sess, checkpoint_dir + 'zsl'+str(epoch+1)+'.ckpt')
                pair_stream.close()
                # Progress files are rewritten every epoch ('w' mode), so an
                # interrupted run keeps the latest statistics.
                with open('save_training_loss.txt','w') as f:
                    for index,_loss in enumerate(loss_list):
                        f.write('epoch'+str(index+1)+', loss:'+str(_loss)+'\n')
                with open('save_val_cost.txt','w') as f:
                    for index, v_l in enumerate(val_results):
                        f.write('epoch'+str((index+1)*self.eval_frequency)+' val loss:'+str(v_l)+'\n')
            # end training
            end_time = time.time()
            print('All costs %f seconds...'%(end_time-start_time))

    def test(self, test_point_file_path, test_size, output_file_path, checkpoint_dir=None, load_model=False):
        """Score the test set with each saved checkpoint, writing one output
        file per checkpoint epoch.

        :param test_point_file_path: pointwise test data path
        :param test_size: number of test examples (used to size the loop)
        :param output_file_path: prefix for per-epoch score files
        :param checkpoint_dir: directory holding 'zsl<epoch>.ckpt' files
        :param load_model: restore checkpoints if True; otherwise score with
            freshly initialized weights
        """
        # Inference-only graph: query + positive document branch.
        input_q = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_q_len])
        input_pos_d = tf.placeholder(tf.int32, shape=[self.batch_size,self.max_d_len])
        q_lens = tf.placeholder(tf.float32, shape=[self.batch_size,])
        q_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_q_len])
        pos_d_mask = tf.placeholder(tf.float32, shape=[self.batch_size,self.max_d_len])
        emb_q = tf.nn.embedding_lookup(self.embeddings,input_q)
        class_vec_sum = tf.reduce_sum(
            tf.multiply(emb_q,tf.expand_dims(q_mask,axis=-1)),
            axis=1
        )
        class_vec = tf.div(class_vec_sum,tf.expand_dims(q_lens,axis=-1))
        emb_pos_d = tf.nn.embedding_lookup(self.embeddings,input_pos_d)
        #get query gate
        query_gate = self.get_class_gate(class_vec, emb_pos_d)
        pos_mult_info = tf.multiply(tf.expand_dims(class_vec, axis=1), emb_pos_d)
        pos_sub_info = tf.expand_dims(class_vec, axis=1) - emb_pos_d
        pos_conv_input = tf.concat([emb_pos_d,pos_mult_info, pos_sub_info], axis=-1)
        # CNN for document
        pos_conv = tf.layers.conv2d(
            inputs = tf.expand_dims(pos_conv_input,axis=-1),
            filters = self.kernal_num,
            kernel_size=[self.kernal_width,self.embedding_size*3],
            strides = [1,self.embedding_size*3],
            padding = 'SAME',
            trainable = True,
            name='doc_conv'
        )
        #shape=[batch,max_dlen,1,kernal_num]
        #reshape to [batch,max_dlen,kernal_num]
        rs_pos_conv = tf.squeeze(pos_conv)
        #query_gate elment-wise multiply rs_pos_conv
        #[batch,kernal_num] , [batch,max_dlen,kernal_num]
        pos_gate_conv = tf.multiply(query_gate, rs_pos_conv)
        #K-max_pooling
        #transpose to [batch,knum,dlen],then get max k in each kernal filter
        transpose_pos_gate_conv = tf.transpose(pos_gate_conv, perm=[0,2,1])
        #[batch,k_num,maxpolling_num]
        pos_kmaxpooling,_ = tf.nn.top_k(
            input=transpose_pos_gate_conv,
            k=self.maxpooling_num,
        )
        pos_encoder = tf.reshape(pos_kmaxpooling, shape=(self.batch_size,-1))
        pos_decoder_mlp1 = tf.layers.dense(
            inputs=pos_encoder,
            units=self.decoder_mlp1_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp1'
        )
        pos_decoder_mlp2 = tf.layers.dense(
            inputs=pos_decoder_mlp1,
            units=self.decoder_mlp2_num,
            activation=tf.nn.tanh,
            trainable=True,
            name='decoder_mlp2'
        )
        score_pos = pos_decoder_mlp2
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        save_var = [v for v in tf.trainable_variables()]
        # Create a local session to run the testing.
        # One pass per saved checkpoint (every checkpoint_steps epochs).
        for i in range(int(self.max_epochs/self.checkpoint_steps)):
            with tf.Session(config=config) as sess:
                test_point_stream = open(test_point_file_path)
                outfile = open(output_file_path+'-epoch'+str(self.checkpoint_steps*(i+1))+'.txt', 'w')
                saver = tf.train.Saver(var_list=save_var)
                if load_model:
                    p = checkpoint_dir + 'zsl'+str(self.checkpoint_steps*(i+1))+'.ckpt'
                    init = tf.global_variables_initializer()
                    sess.run(init)
                    saver.restore(sess, p)
                    print ("data loaded!")
                else:
                    init = tf.global_variables_initializer()
                    sess.run(init)
                # Loop through training steps.
                for b in range(int(np.ceil(float(test_size)/self.batch_size))):
                    X = next(self.test_data_generator.test_pairwise_reader(test_point_stream, self.batch_size))
                    # Skip incomplete trailing batches.
                    if(X[u'q'].shape[0] != self.batch_size):
                        continue
                    query = X[u'q']
                    pos_doc = X[u'd']
                    test_q_lens = X[u'q_lens']
                    M_query = self.gen_query_mask(query)
                    M_pos = self.gen_doc_mask(pos_doc)
                    test_feed_dict = {input_q: query, input_pos_d: pos_doc, q_lens: test_q_lens, q_mask: M_query, pos_d_mask: M_pos}
                    # Run the graph and fetch some of the nodes.
                    scores = sess.run(score_pos, feed_dict=test_feed_dict)
                    for score in scores:
                        outfile.write('{0}\n'.format(score[0]))
                outfile.close()
                test_point_stream.close()
class DockerSpawner(Spawner):
    """A Spawner for JupyterHub that runs each user's server in a separate docker container"""

    # Executor is shared across all spawner instances (class attribute),
    # so all docker API calls for this Hub process go through one worker thread.
    _executor = None

    @property
    def executor(self):
        """single global executor"""
        cls = self.__class__
        if cls._executor is None:
            cls._executor = ThreadPoolExecutor(1)
        return cls._executor

    # Single docker client shared by all spawner instances in this process.
    _client = None

    @property
    def client(self):
        """single global client instance"""
        cls = self.__class__
        if cls._client is None:
            kwargs = {"version": "auto"}
            if self.tls_config:
                kwargs["tls"] = docker.tls.TLSConfig(**self.tls_config)
            # env vars (DOCKER_HOST etc.) and user-provided kwargs override defaults
            kwargs.update(kwargs_from_env())
            kwargs.update(self.client_kwargs)
            client = docker.APIClient(**kwargs)
            cls._client = client
        return cls._client

    # notice when user has set the command
    # default command is that of the container,
    # but user can override it via config
    _user_set_cmd = False

    @observe("cmd")
    def _cmd_changed(self, change):
        # any assignment to .cmd (even to the same value) marks it user-set
        self._user_set_cmd = True

    object_id = Unicode()

    # the type of object we create
    object_type = "container"

    # the field containing the object id
    object_id_key = "Id"

    @property
    def container_id(self):
        """alias for object_id"""
        return self.object_id

    @property
    def container_name(self):
        """alias for object_name"""
        return self.object_name

    # deprecate misleading container_ip, since
    # it is not the ip in the container,
    # but the host ip of the port forwarded to the container
    # when use_internal_ip is False
    container_ip = Unicode("127.0.0.1", config=True)

    @observe("container_ip")
    def _container_ip_deprecated(self, change):
        self.log.warning(
            "DockerSpawner.container_ip is deprecated in dockerspawner-0.9."
            " Use DockerSpawner.host_ip to specify the host ip that is forwarded to the container"
        )
        self.host_ip = change.new

    host_ip = Unicode(
        "127.0.0.1",
        help="""The ip address on the host on which to expose the container's port

        Typically 127.0.0.1, but can be public interfaces as well
        in cases where the Hub and/or proxy are on different machines
        from the user containers.

        Only used when use_internal_ip = False.
        """,
        config=True,
    )

    @default('host_ip')
    def _default_host_ip(self):
        # If DOCKER_HOST points at a remote tcp daemon, forwarded ports
        # are exposed on that host, not on localhost.
        docker_host = os.getenv('DOCKER_HOST')
        if docker_host:
            urlinfo = urlparse(docker_host)
            if urlinfo.scheme == 'tcp':
                return urlinfo.hostname
        return '127.0.0.1'

    # unlike container_ip, container_port is the internal port
    # on which the server is bound.
    container_port = Int(8888, min=1, max=65535, config=True)

    @observe("container_port")
    def _container_port_changed(self, change):
        self.log.warning(
            "DockerSpawner.container_port is deprecated in dockerspawner 0.9."
            " Use DockerSpawner.port")
        self.port = change.new

    # fix default port to 8888, used in the container
    @default("port")
    def _port_default(self):
        return 8888

    # default to listening on all-interfaces in the container
    @default("ip")
    def _ip_default(self):
        return "0.0.0.0"

    # DEPRECATED alias for .image (see _container_image_changed below)
    container_image = Unicode("jupyterhub/singleuser:%s" % _jupyterhub_xy, config=True)

    @observe("container_image")
    def _container_image_changed(self, change):
        self.log.warning(
            "DockerSpawner.container_image is deprecated in dockerspawner 0.9."
            " Use DockerSpawner.image")
        self.image = change.new

    image = Unicode(
        "jupyterhub/singleuser:%s" % _jupyterhub_xy,
        config=True,
        help="""The image to use for single-user servers.

        This image should have the same version of jupyterhub as
        the Hub itself installed.

        If the default command of the image does not launch
        jupyterhub-singleuser, set `c.Spawner.cmd` to
        launch jupyterhub-singleuser, e.g.

        Any of the jupyter docker-stacks should work without additional config,
        as long as the version of jupyterhub in the image is compatible.
        """,
    )

    image_whitelist = Union(
        [Any(), Dict(), List()],
        default_value={},
        config=True,
        help="""
        List or dict of images that users can run.

        If specified, users will be presented with a form
        from which they can select an image to run.

        If a dictionary, the keys will be the options presented to users
        and the values the actual images that will be launched.

        If a list, will be cast to a dictionary where keys and values are the same
        (i.e. a shortcut for presenting the actual images directly to users).

        If a callable, will be called with the Spawner instance as its only argument.
        The user is accessible as spawner.user.
        The callable should return a dict or list as above.
        """,
    )

    @validate('image_whitelist')
    def _image_whitelist_dict(self, proposal):
        """cast image_whitelist to a dict

        If passing a list, cast it to a {item:item}
        dict where the keys and values are the same.
        """
        whitelist = proposal.value
        if isinstance(whitelist, list):
            whitelist = {item: item for item in whitelist}
        return whitelist

    def _get_image_whitelist(self):
        """Evaluate image_whitelist callable

        Or return the whitelist as-is if it's already a dict
        """
        if callable(self.image_whitelist):
            whitelist = self.image_whitelist(self)
            if not isinstance(whitelist, dict):
                # always return a dict
                whitelist = {item: item for item in whitelist}
            return whitelist
        return self.image_whitelist

    @default('options_form')
    def _default_options_form(self):
        image_whitelist = self._get_image_whitelist()
        if len(image_whitelist) <= 1:
            # default form only when there are images to choose from
            return ''
        # form derived from wrapspawner.ProfileSpawner
        option_t = '<option value="{image}" {selected}>{image}</option>'
        options = [
            option_t.format(image=image, selected='selected' if image == self.image else '')
            for image in image_whitelist
        ]
        # NOTE(review): `options` is a list here, so str.format inserts its
        # Python repr (brackets, quotes) into the HTML — presumably this should
        # be '\n'.join(options); verify against the rendered spawn form.
        return """
        <label for="image">Select an image:</label>
        <select class="form-control" name="image" required autofocus>
        {options}
        </select>
        """.format(options=options)

    def options_from_form(self, formdata):
        """Turn options formdata into user_options"""
        options = {}
        if 'image' in formdata:
            # form values arrive as lists; take the single selection
            options['image'] = formdata['image'][0]
        return options

    container_prefix = Unicode(config=True, help="DEPRECATED in 0.10. Use prefix")

    container_name_template = Unicode(
        config=True, help="DEPRECATED in 0.10. Use name_template")

    @observe("container_name_template", "container_prefix")
    def _deprecate_container_alias(self, change):
        # forward deprecated container_* traits to their new names
        # (strip the "container_" prefix to get the target trait name)
        new_name = change.name[len("container_"):]
        setattr(self, new_name, change.new)

    prefix = Unicode(
        "jupyter",
        config=True,
        help=dedent("""
            Prefix for container names. See name_template for full container name
            for a particular user's server.
            """),
    )

    name_template = Unicode(
        "{prefix}-{username}",
        config=True,
        help=dedent("""
            Name of the container or service: with {username}, {imagename},
            {prefix} replacements.
            The default name_template is <prefix>-<username> for backward compatibility.
            """),
    )

    client_kwargs = Dict(
        config=True,
        help="Extra keyword arguments to pass to the docker.Client constructor.",
    )

    volumes = Dict(
        config=True,
        help=dedent("""
            Map from host file/directory to container (guest) file/directory
            mount point and (optionally) a mode. When specifying the
            guest mount point (bind) for the volume, you may use a
            dict or str. If a str, then the volume will default to a
            read-write (mode="rw"). With a dict, the bind is
            identified by "bind" and the "mode" may be one of "rw"
            (default), "ro" (read-only), "z" (public/shared SELinux
            volume label), and "Z" (private/unshared SELinux volume label).

            If format_volume_name is not set,
            default_format_volume_name is used for naming volumes.
            In this case, if you use {username} in either the host or guest
            file/directory path, it will be replaced with the current
            user's name.
            """),
    )

    read_only_volumes = Dict(
        config=True,
        help=dedent("""
            Map from host file/directory to container file/directory.
            Volumes specified here will be read-only in the container.

            If format_volume_name is not set,
            default_format_volume_name is used for naming volumes.
            In this case, if you use {username} in either the host or guest
            file/directory path, it will be replaced with the current
            user's name.
            """),
    )

    format_volume_name = Any(
        help="""Any callable that accepts a string template and a DockerSpawner
        instance as parameters in that order and returns a string.

        Reusable implementations should go in dockerspawner.VolumeNamingStrategy,
        tests should go in ...
        """).tag(config=True)

    # NOTE: defined in the class body without `self` on purpose — it is stored
    # and later called as a plain function f(template, spawner), not as a method.
    def default_format_volume_name(template, spawner):
        return template.format(username=spawner.user.name)

    @default("format_volume_name")
    def _get_default_format_volume_name(self):
        return default_format_volume_name

    use_docker_client_env = Bool(
        True,
        config=True,
        help="DEPRECATED. Docker env variables are always used if present.",
    )

    @observe("use_docker_client_env")
    def _client_env_changed(self):
        self.log.warning(
            "DockerSpawner.use_docker_client_env is deprecated and ignored."
            " Docker environment variables are always used if defined.")

    tls_config = Dict(
        config=True,
        help="""Arguments to pass to docker TLS configuration.

        See docker.client.TLSConfig constructor for options.
        """,
    )

    # all six deprecated TLS traits share one Any() descriptor
    tls = tls_verify = tls_ca = tls_cert = tls_key = tls_assert_hostname = Any(
        config=True,
        help="""DEPRECATED. Use DockerSpawner.tls_config dict to set any TLS options.""",
    )

    @observe("tls", "tls_verify", "tls_ca", "tls_cert", "tls_key",
             "tls_assert_hostname")
    def _tls_changed(self, change):
        self.log.warning(
            "%s config ignored, use %s.tls_config dict to set full TLS configuration.",
            change.name,
            self.__class__.__name__,
        )

    remove_containers = Bool(
        False, config=True, help="DEPRECATED in DockerSpawner 0.10. Use .remove")

    @observe("remove_containers")
    def _deprecate_remove_containers(self, change):
        # preserve remove_containers alias to .remove
        self.remove = change.new

    remove = Bool(
        False,
        config=True,
        help="""
        If True, delete containers when servers are stopped.

        This will destroy any data in the container not stored in mounted volumes.
        """,
    )

    @property
    def will_resume(self):
        # indicate that we will resume,
        # so JupyterHub >= 0.7.1 won't cleanup our API token
        return not self.remove

    extra_create_kwargs = Dict(
        config=True, help="Additional args to pass for container create")

    extra_host_config = Dict(
        config=True,
        help="Additional args to create_host_config for container create")

    # characters allowed unescaped in docker object names
    _docker_safe_chars = set(string.ascii_letters + string.digits + "-")
    _docker_escape_char = "_"

    hub_ip_connect = Unicode(
        config=True,
        help=dedent("""
            If set, DockerSpawner will configure the containers to use
            the specified IP to connect the hub api. This is useful
            when the hub_api is bound to listen on all ports or is
            running inside of a container.
            """),
    )

    @observe("hub_ip_connect")
    def _ip_connect_changed(self, change):
        if jupyterhub.version_info >= (0, 8):
            warnings.warn(
                "DockerSpawner.hub_ip_connect is no longer needed with JupyterHub 0.8."
                " Use JupyterHub.hub_connect_ip instead.",
                DeprecationWarning,
            )

    use_internal_ip = Bool(
        False,
        config=True,
        help=dedent("""
            Enable the usage of the internal docker ip. This is useful if you are running
            jupyterhub (as a container) and the user containers within the same docker network.
            E.g. by mounting the docker socket of the host into the jupyterhub container.
            Default is True if using a docker network, False if bridge or host networking is used.
            """),
    )

    @default("use_internal_ip")
    def _default_use_ip(self):
        # setting network_name to something other than bridge or host implies use_internal_ip
        if self.network_name not in {"bridge", "host"}:
            return True
        else:
            return False

    links = Dict(
        config=True,
        help=dedent("""
            Specify docker link mapping to add to the container, e.g.

                links = {'jupyterhub': 'jupyterhub'}

            If the Hub is running in a Docker container,
            this can simplify routing because all traffic will be using docker hostnames.
            """),
    )

    network_name = Unicode(
        "bridge",
        config=True,
        help=dedent("""
            Run the containers on this docker network.
            If it is an internal docker network, the Hub should be on the same network,
            as internal docker IP addresses will be used.
            For bridge networking, external ports will be bound.
            """),
    )

    @property
    def tls_client(self):
        """A tuple consisting of the TLS client certificate and key if they
        have been provided, otherwise None.
        """
        if self.tls_cert and self.tls_key:
            return (self.tls_cert, self.tls_key)
        return None

    @property
    def volume_mount_points(self):
        """
        Volumes are declared in docker-py in two stages. First, you declare
        all the locations where you're going to mount volumes when you call
        create_container.

        Returns a sorted list of all the values in self.volumes or
        self.read_only_volumes.
        """
        return sorted([value["bind"] for value in self.volume_binds.values()])

    @property
    def volume_binds(self):
        """
        The second half of declaring a volume with docker-py happens when you
        actually call start(). The required format is a dict of dicts that
        looks like:

        {
            host_location: {'bind': container_location, 'mode': 'rw'}
        }

        mode may be 'ro', 'rw', 'z', or 'Z'.
        """
        # read-write binds first, then read-only binds merged into the same dict
        binds = self._volumes_to_binds(self.volumes, {})
        return self._volumes_to_binds(self.read_only_volumes, binds, mode="ro")

    # cached result of escaped_name (computed once per spawner instance)
    _escaped_name = None

    @property
    def escaped_name(self):
        """Escape the username so it's safe for docker objects"""
        if self._escaped_name is None:
            self._escaped_name = escape(
                self.user.name,
                safe=self._docker_safe_chars,
                escape_char=self._docker_escape_char,
            )
        return self._escaped_name

    # NOTE(review): object_id is re-declared here (first declared above as
    # Unicode() without allow_none) — this later declaration wins; confirm the
    # duplicate is intentional.
    object_id = Unicode(allow_none=True)

    @property
    def object_name(self):
        """Render the name of our container/service using name_template"""
        escaped_image = self.image.replace("/", "_")
        server_name = getattr(self, "name", "")
        d = {
            "username": self.escaped_name,
            "imagename": escaped_image,
            "servername": server_name,
            "prefix": self.prefix,
        }
        return self.name_template.format(**d)

    def load_state(self, state):
        """Restore the container/object id from saved spawner state."""
        super(DockerSpawner, self).load_state(state)
        if "container_id" in state:
            # backward-compatibility for dockerspawner < 0.10
            self.object_id = state.get("container_id")
        else:
            self.object_id = state.get("object_id", "")

    def get_state(self):
        """Persist the object id so the container can be found after a Hub restart."""
        state = super(DockerSpawner, self).get_state()
        if self.object_id:
            state["object_id"] = self.object_id
        return state

    def _public_hub_api_url(self):
        # rewrite the hub api url host to hub_ip_connect, keeping scheme and port
        proto, path = self.hub.api_url.split("://", 1)
        ip, rest = path.split(":", 1)
        return "{proto}://{ip}:{rest}".format(proto=proto, ip=self.hub_ip_connect, rest=rest)

    def _env_keep_default(self):
        """Don't inherit any env from the parent process"""
        return []

    def get_args(self):
        """Return singleuser server args, rewriting --hub-api-url if hub_ip_connect is set."""
        args = super().get_args()
        if self.hub_ip_connect:
            # JupyterHub 0.7 specifies --hub-api-url
            # on the command-line, which is hard to update
            for idx, arg in enumerate(list(args)):
                if arg.startswith("--hub-api-url="):
                    args.pop(idx)
                    break
            args.append("--hub-api-url=%s" % self._public_hub_api_url())
        return args

    def _docker(self, method, *args, **kwargs):
        """wrapper for calling docker methods

        to be passed to ThreadPoolExecutor
        """
        m = getattr(self.client, method)
        return m(*args, **kwargs)

    def docker(self, method, *args, **kwargs):
        """Call a docker method in a background thread

        returns a Future
        """
        return self.executor.submit(self._docker, method, *args, **kwargs)

    @gen.coroutine
    def poll(self):
        """Check for my id in `docker ps`"""
        container = yield self.get_object()
        if not container:
            self.log.warning("Container not found: %s", self.container_name)
            # 0 == "process exited" per the Spawner.poll contract
            return 0
        container_state = container["State"]
        self.log.debug("Container %s status: %s", self.container_id[:7],
                       pformat(container_state))
        if container_state["Running"]:
            # None == still running
            return None
        else:
            # non-None: report exit status details
            return ("ExitCode={ExitCode}, "
                    "Error='{Error}', "
                    "FinishedAt={FinishedAt}".format(**container_state))

    @gen.coroutine
    def get_object(self):
        """Inspect our container/service; return its info dict or None if gone/unhealthy."""
        self.log.debug("Getting container '%s'", self.object_name)
        try:
            obj = yield self.docker("inspect_%s" % self.object_type, self.object_name)
            self.object_id = obj[self.object_id_key]
        except APIError as e:
            if e.response.status_code == 404:
                self.log.info("%s '%s' is gone", self.object_type.title(), self.object_name)
                obj = None
                # my container is gone, forget my id
                self.object_id = ""
            elif e.response.status_code == 500:
                self.log.info(
                    "%s '%s' is on unhealthy node",
                    self.object_type.title(),
                    self.object_name,
                )
                obj = None
                # my container is unhealthy, forget my id
                self.object_id = ""
            else:
                raise
        return obj

    @gen.coroutine
    def get_command(self):
        """Get the command to run (full command + args)"""
        if self._user_set_cmd:
            cmd = self.cmd
        else:
            # fall back to the image's default CMD
            image_info = yield self.docker("inspect_image", self.image)
            cmd = image_info["Config"]["Cmd"]
        return cmd + self.get_args()

    @gen.coroutine
    def remove_object(self):
        """Remove the container/service and its anonymous volumes."""
        self.log.info("Removing %s %s", self.object_type, self.object_id)
        # remove the container, as well as any associated volumes
        yield self.docker("remove_" + self.object_type, self.object_id, v=True)

    @gen.coroutine
    def check_image_whitelist(self, image):
        """Validate a user-requested image against the whitelist; resolve aliases."""
        image_whitelist = self._get_image_whitelist()
        if not image_whitelist:
            # empty whitelist means any image is allowed
            return image
        if image not in image_whitelist:
            raise web.HTTPError(
                400,
                "Image %s not in whitelist: %s" % (image, ', '.join(image_whitelist)),
            )
        # resolve image alias to actual image name
        return image_whitelist[image]

    @gen.coroutine
    def create_object(self):
        """Create the container/service object"""
        # image priority:
        # 1. user options (from spawn options form)
        # 2. self.image from config
        image_option = self.user_options.get('image')
        if image_option:
            # save choice in self.image
            self.image = yield self.check_image_whitelist(image_option)

        create_kwargs = dict(
            image=self.image,
            environment=self.get_env(),
            volumes=self.volume_mount_points,
            name=self.container_name,
            command=(yield self.get_command()),
        )

        # ensure internal port is exposed
        create_kwargs["ports"] = {"%i/tcp" % self.port: None}

        create_kwargs.update(self.extra_create_kwargs)

        # build the dictionary of keyword arguments for host_config
        host_config = dict(binds=self.volume_binds, links=self.links)

        if getattr(self, "mem_limit", None) is not None:
            # If jupyterhub version > 0.7, mem_limit is a traitlet that can
            # be directly configured. If so, use it to set mem_limit.
            # this will still be overriden by extra_host_config
            host_config["mem_limit"] = self.mem_limit

        if not self.use_internal_ip:
            # bind the container port to a host port on host_ip
            host_config["port_bindings"] = {self.port: (self.host_ip, )}
        host_config.update(self.extra_host_config)
        host_config.setdefault("network_mode", self.network_name)

        self.log.debug("Starting host with config: %s", host_config)

        host_config = self.client.create_host_config(**host_config)
        create_kwargs.setdefault("host_config", {}).update(host_config)

        # create the container
        obj = yield self.docker("create_container", **create_kwargs)
        return obj

    @gen.coroutine
    def start_object(self):
        """Actually start the container/service

        e.g. calling `docker start`
        """
        # NOTE(review): returns the executor Future without yielding it, so the
        # coroutine resolves before the docker call completes — confirm callers
        # tolerate this (compare remove_object, which yields).
        return self.docker("start", self.container_id)

    @gen.coroutine
    def stop_object(self):
        """Actually stop the container/service

        e.g. calling `docker stop`
        """
        # NOTE(review): same un-yielded Future as start_object — verify.
        return self.docker("stop", self.container_id)

    @gen.coroutine
    def start(self, image=None, extra_create_kwargs=None, extra_host_config=None):
        """Start the single-user server in a docker container.

        Additional arguments to create/host config/etc.
        can be specified via .extra_create_kwargs and .extra_host_config attributes.
        """
        if image:
            self.log.warning("Specifying image via .start args is deprecated")
            self.image = image
        if extra_create_kwargs:
            self.log.warning(
                "Specifying extra_create_kwargs via .start args is deprecated")
            self.extra_create_kwargs.update(extra_create_kwargs)
        if extra_host_config:
            self.log.warning(
                "Specifying extra_host_config via .start args is deprecated")
            self.extra_host_config.update(extra_host_config)

        image = self.image

        obj = yield self.get_object()
        if obj and self.remove:
            # leftover container from a previous run that should have been removed
            self.log.warning(
                "Removing %s that should have been cleaned up: %s (id: %s)",
                self.object_type,
                self.object_name,
                self.object_id[:7],
            )
            yield self.remove_object()
            obj = None

        if obj is None:
            obj = yield self.create_object()
            self.object_id = obj[self.object_id_key]
            self.log.info(
                "Created %s %s (id: %s) from image %s",
                self.object_type,
                self.object_name,
                self.object_id[:7],
                self.image,
            )
        else:
            self.log.info(
                "Found existing %s %s (id: %s)",
                self.object_type,
                self.object_name,
                self.object_id[:7],
            )
            # Handle re-using API token.
            # Get the API token from the environment variables
            # of the running container:
            for line in obj["Config"]["Env"]:
                if line.startswith(("JPY_API_TOKEN=", "JUPYTERHUB_API_TOKEN=")):
                    self.api_token = line.split("=", 1)[1]
                    break

        # TODO: handle unpause
        self.log.info(
            "Starting %s %s (id: %s)",
            self.object_type,
            self.object_name,
            self.container_id[:7],
        )

        # start the container
        yield self.start_object()

        ip, port = yield self.get_ip_and_port()
        if jupyterhub.version_info < (0, 7):
            # store on user for pre-jupyterhub-0.7:
            self.user.server.ip = ip
            self.user.server.port = port
        # jupyterhub 0.7 prefers returning ip, port:
        return (ip, port)

    @gen.coroutine
    def get_ip_and_port(self):
        """Queries Docker daemon for container's IP and port.

        If you are using network_mode=host, you will need to override
        this method as follows::

            @gen.coroutine
            def get_ip_and_port(self):
                return self.host_ip, self.port

        You will need to make sure host_ip and port
        are correct, which depends on the route to the container
        and the port it opens.
        """
        if self.use_internal_ip:
            # use the container's own IP on the docker network
            resp = yield self.docker("inspect_container", self.container_id)
            network_settings = resp["NetworkSettings"]
            if "Networks" in network_settings:
                ip = self.get_network_ip(network_settings)
            else:
                # Fallback for old versions of docker (<1.9) without network management
                ip = network_settings["IPAddress"]
            port = self.port
        else:
            # use the host port forwarded to the container
            resp = yield self.docker("port", self.container_id, self.port)
            if resp is None:
                raise RuntimeError("Failed to get port info for %s" % self.container_id)
            ip = resp[0]["HostIp"]
            port = int(resp[0]["HostPort"])

        if ip == "0.0.0.0":
            # all-interfaces bind: connect via the docker daemon's host instead
            ip = urlparse(self.client.base_url).hostname
            if ip == "localnpipe":
                # Windows named-pipe URL has no usable hostname
                ip = "localhost"

        return ip, port

    def get_network_ip(self, network_settings):
        """Return the container's IP on self.network_name; raise if not attached."""
        networks = network_settings["Networks"]
        if self.network_name not in networks:
            raise Exception(
                "Unknown docker network '{network}'."
                " Did you create it with `docker network create <name>`?".
                format(network=self.network_name))
        network = networks[self.network_name]
        ip = network["IPAddress"]
        return ip

    @gen.coroutine
    def stop(self, now=False):
        """Stop the container

        Consider using pause/unpause when docker-py adds support
        """
        self.log.info(
            "Stopping %s %s (id: %s)",
            self.object_type,
            self.object_name,
            self.object_id[:7],
        )
        yield self.stop_object()
        if self.remove:
            yield self.remove_object()

        self.clear_state()

    def _volumes_to_binds(self, volumes, binds, mode="rw"):
        """Extract the volume mount points from volumes property.

        Returns a dict of dict entries of the form::

            {'/host/dir': {'bind': '/guest/dir': 'mode': 'rw'}}
        """

        def _fmt(v):
            # apply format_volume_name (e.g. {username} substitution)
            return self.format_volume_name(v, self)

        for k, v in volumes.items():
            m = mode
            if isinstance(v, dict):
                # dict form may carry an explicit per-volume mode
                if "mode" in v:
                    m = v["mode"]
                v = v["bind"]
            binds[_fmt(k)] = {"bind": _fmt(v), "mode": m}
        return binds
class SwarmSpawner(Spawner):
    """A Spawner for JupyterHub using Docker Engine in Swarm mode
    """

    # shared executor for all spawner instances (class attribute)
    _executor = None

    @property
    def executor(self, max_workers=1):
        """single global executor"""
        # NOTE(review): a property getter cannot receive arguments via
        # attribute access, so max_workers is always its default of 1 here.
        cls = self.__class__
        if cls._executor is None:
            cls._executor = ThreadPoolExecutor(max_workers)
        return cls._executor

    # shared docker client for all spawner instances
    _client = None

    @property
    def client(self):
        """single global client instance"""
        cls = self.__class__
        if cls._client is None:
            kwargs = {}
            if self.tls_config:
                kwargs['tls'] = docker.tls.TLSConfig(**self.tls_config)
            # env vars (DOCKER_HOST etc.) override defaults
            kwargs.update(kwargs_from_env())
            client = docker.APIClient(version='auto', **kwargs)
            cls._client = client
        return cls._client

    service_id = Unicode()
    service_port = Int(8888, min=1, max=65535, config=True)
    service_image = Unicode("jupyterhub/singleuser", config=True)
    service_prefix = Unicode(
        "jupyter",
        config=True,
        help=dedent("""
            Prefix for service names. The full service name for a particular
            user will be <prefix>-<hash(username)>-<server_name>.
            """))
    tls_config = Dict(
        config=True,
        help=dedent("""Arguments to pass to docker TLS configuration.
            Check for more info: http://docker-py.readthedocs.io/en/stable/tls.html
            """))

    container_spec = Dict({}, config=True, help="Params to create the service")
    resource_spec = Dict(
        {}, config=True, help="Params about cpu and memory limits")

    placement = List(
        [],
        config=True,
        help=dedent("""List of placement constraints into the swarm
            """))

    networks = List(
        [],
        config=True,
        help=dedent("""Additional args to create_host_config for service create
            """))
    use_user_options = Bool(
        False,
        config=True,
        help=dedent("""the spawner will use the dict passed through the form
            or as json body when using the Hub Api
            """))
    jupyterhub_service_name = Unicode(
        config=True,
        help=dedent("""Name of the service running the JupyterHub
            """))

    @property
    def tls_client(self):
        """A tuple consisting of the TLS client certificate and key if they
        have been provided, otherwise None.
        """
        if self.tls_cert and self.tls_key:
            return (self.tls_cert, self.tls_key)
        return None

    # cached md5 hex digest of the username (see service_owner)
    _service_owner = None

    @property
    def service_owner(self):
        """md5 hash of the username, used in the service name (docker-safe)."""
        if self._service_owner is None:
            m = hashlib.md5()
            m.update(self.user.name.encode('utf-8'))
            self._service_owner = m.hexdigest()
        return self._service_owner

    @property
    def service_name(self):
        """
        Service name inside the Docker Swarm

        service_suffix should be a numerical value unique for user
        {service_prefix}-{service_owner}-{service_suffix}
        """
        if hasattr(self, "server_name") and self.server_name:
            server_name = self.server_name
        else:
            # default suffix when no named server is set
            server_name = 1
        return "{}-{}-{}".format(self.service_prefix, self.service_owner,
                                 server_name)

    def load_state(self, state):
        """Restore the service id from saved spawner state."""
        super().load_state(state)
        self.service_id = state.get('service_id', '')

    def get_state(self):
        """Persist the service id so the service can be found after a Hub restart."""
        state = super().get_state()
        if self.service_id:
            state['service_id'] = self.service_id
        return state

    def _env_keep_default(self):
        """it's called in traitlets.
        It's a special method name. Don't inherit any env from the parent process"""
        return []

    def _public_hub_api_url(self):
        # replace the hub api host with the jupyterhub service name, so the
        # singleuser container reaches the Hub via swarm service discovery
        proto, path = self.hub.api_url.split('://', 1)
        _, rest = path.split(':', 1)
        return '{proto}://{name}:{rest}'.format(
            proto=proto, name=self.jupyterhub_service_name, rest=rest)

    def get_env(self):
        """Env for the singleuser server, with JPY_* vars required by the image."""
        env = super().get_env()
        env.update(
            dict(JPY_USER=self.user.name,
                 JPY_COOKIE_NAME=self.user.server.cookie_name,
                 JPY_BASE_URL=self.user.server.base_url,
                 JPY_HUB_PREFIX=self.hub.server.base_url))

        if self.notebook_dir:
            env['NOTEBOOK_DIR'] = self.notebook_dir

        env['JPY_HUB_API_URL'] = self._public_hub_api_url()

        return env

    def _docker(self, method, *args, **kwargs):
        """wrapper for calling docker methods

        to be passed to ThreadPoolExecutor
        """
        m = getattr(self.client, method)
        return m(*args, **kwargs)

    def docker(self, method, *args, **kwargs):
        """Call a docker method in a background thread

        returns a Future
        """
        return self.executor.submit(self._docker, method, *args, **kwargs)

    @gen.coroutine
    def poll(self):
        """Check for a task state like `docker service ps id`"""
        service = yield self.get_service()
        if not service:
            self.log.warn("Docker service not found")
            # 0 == "process exited" per the Spawner.poll contract
            return 0

        task_filter = {'service': service['Spec']['Name']}
        tasks = yield self.docker('tasks', task_filter)

        running_task = None
        for task in tasks:
            task_state = task['Status']['State']
            self.log.debug(
                "Task %s of Docker service %s status: %s",
                task['ID'][:7],
                self.service_id[:7],
                pformat(task_state),
            )
            if task_state == 'running':
                # there should be at most one running task
                running_task = task

        if running_task is not None:
            # None == still running
            return None
        else:
            # non-None exit status: no running task found
            return 1

    @gen.coroutine
    def get_service(self):
        """Inspect our swarm service; return its info dict or None if gone/unhealthy."""
        self.log.debug("Getting Docker service '%s'", self.service_name)
        try:
            service = yield self.docker('inspect_service', self.service_name)
            self.service_id = service['ID']
        except APIError as err:
            if err.response.status_code == 404:
                self.log.info("Docker service '%s' is gone", self.service_name)
                service = None
                # Docker service is gone, remove service id
                self.service_id = ''
            elif err.response.status_code == 500:
                self.log.info("Docker Swarm Server error")
                service = None
                # Docker service is unhealthy, remove the service_id
                self.service_id = ''
            else:
                raise
        return service

    @gen.coroutine
    def start(self):
        """Start the single-user server in a docker service.
        You can specify the params for the service through jupyterhub_config.py
        or using the user_options
        """
        # https://github.com/jupyterhub/jupyterhub/blob/master/jupyterhub/user.py#L202
        # By default jupyterhub calls the spawner passing user_options
        if self.use_user_options:
            user_options = self.user_options
        else:
            user_options = {}

        self.log.warn("user_options: {}".format(user_options))

        service = yield self.get_service()

        if service is None:
            if 'name' in user_options:
                self.server_name = user_options['name']
            if hasattr(self, 'container_spec') and self.container_spec is not None:
                # copy so config dict is not mutated below
                container_spec = dict(**self.container_spec)
            elif user_options == {}:
                # NOTE(review): raising a plain str is a TypeError in Python 3 —
                # presumably this should raise an Exception subclass; verify.
                raise ("A container_spec is needed in to create a service")

            container_spec.update(user_options.get('container_spec', {}))

            # iterates over mounts to create
            # a new mounts list of docker.types.Mount
            container_spec['mounts'] = []
            for mount in self.container_spec['mounts']:
                m = dict(**mount)
                if 'source' in m:
                    # substitute the hashed username into the volume source
                    m['source'] = m['source'].format(
                        username=self.service_owner)
                if 'driver_config' in m:
                    device = m['driver_config']['options']['device'].format(
                        username=self.service_owner)
                    m['driver_config']['options']['device'] = device
                    m['driver_config'] = docker.types.DriverConfig(
                        **m['driver_config'])
                container_spec['mounts'].append(docker.types.Mount(**m))

            # some Envs are required by the single-user-image
            container_spec['env'] = self.get_env()

            if hasattr(self, 'resource_spec'):
                # copy so config dict is not mutated below
                resource_spec = dict(**self.resource_spec)
            resource_spec.update(user_options.get('resource_spec', {}))
            # enable to set a human readable memory unit
            if 'mem_limit' in resource_spec:
                resource_spec['mem_limit'] = parse_bytes(
                    resource_spec['mem_limit'])
            if 'mem_reservation' in resource_spec:
                resource_spec['mem_reservation'] = parse_bytes(
                    resource_spec['mem_reservation'])

            if hasattr(self, 'networks'):
                networks = self.networks
            if user_options.get('networks') is not None:
                # user_options override the configured networks
                networks = user_options.get('networks')

            if hasattr(self, 'placement'):
                placement = self.placement
            if user_options.get('placement') is not None:
                # user_options override the configured placement constraints
                placement = user_options.get('placement')

            # the image is passed positionally to ContainerSpec, not as a kwarg
            image = container_spec['Image']
            del container_spec['Image']

            # create the service
            container_spec = docker.types.ContainerSpec(
                image, **container_spec)
            resources = docker.types.Resources(**resource_spec)

            task_spec = {
                'container_spec': container_spec,
                'resources': resources,
                'placement': placement
            }

            task_tmpl = docker.types.TaskTemplate(**task_spec)
            resp = yield self.docker('create_service',
                                     task_tmpl,
                                     name=self.service_name,
                                     networks=networks)

            self.service_id = resp['ID']

            self.log.info("Created Docker service '%s' (id: %s) from image %s",
                          self.service_name, self.service_id[:7], image)

        else:
            self.log.info("Found existing Docker service '%s' (id: %s)",
                          self.service_name, self.service_id[:7])
            # Handle re-using API token.
            # Get the API token from the environment variables
            # of the running service:
            envs = service['Spec']['TaskTemplate']['ContainerSpec']['Env']
            for line in envs:
                if line.startswith('JPY_API_TOKEN='):
                    self.api_token = line.split('=', 1)[1]
                    break

        ip = self.service_name
        port = self.service_port

        # we use service_name instead of ip
        # https://docs.docker.com/engine/swarm/networking/#use-swarm-mode-service-discovery
        # service_port is actually equal to 8888
        return (ip, port)

    @gen.coroutine
    def stop(self, now=False):
        """Stop and remove the service

        Consider using stop/start when Docker adds support
        """
        self.log.info("Stopping and removing Docker service %s (id: %s)",
                      self.service_name, self.service_id[:7])
        # NOTE(review): passes the truncated 7-char id prefix to remove_service —
        # docker resolves id prefixes, but the full self.service_id would be safer.
        yield self.docker('remove_service', self.service_id[:7])
        self.log.info("Docker service %s (id: %s) removed",
                      self.service_name, self.service_id[:7])
        self.clear_state()