Ejemplo n.º 1
0
class VisibilityPipe(Pipe):
    """Pipe that re-serializes the inlet root without hidden elements.

    The outlet value is rebuilt from the inlet's root via ``convert_elkjson``
    using a ``VisIndex`` computed from that same root.
    """

    observes = TypedTuple(
        T.Unicode(),
        default_value=(F.AnyHidden, F.Layout),
    )

    @T.default("reports")
    def _default_reports(self):
        """Default value of ``reports``: a one-tuple holding ``F.Layout``."""
        reported = (F.Layout,)
        return reported

    async def run(self):
        """Rebuild the outlet value from the inlet root.

        Returns ``None`` when either the outlet or the inlet is missing,
        otherwise the outlet after its ``value`` has been replaced.
        """
        if self.outlet is None:
            return
        if self.inlet is None:
            return

        source_root = self.inlet.index.root

        # Index of hidden elements, then drop stale slack css classes.
        hidden_index = VisIndex.from_els(source_root)
        hidden_index.clear_slack(source_root)

        # Serialize while both exclusion contexts are active.
        with exclude_hidden:
            with exclude_layout:
                serialized = source_root.dict()

        # New root node with slack edges / ports introduced due to hidden
        # elements; ids are refreshed while the registry context is active.
        with Registry():
            rebuilt = convert_elkjson(serialized, hidden_index)
            for element in index.iter_elements(rebuilt):
                element.id = element.get_id()

        self.outlet.value = rebuilt
        return self.outlet
Ejemplo n.º 2
0
class InputSlot(ipywidgets.Widget):
    """Widget model describing one input slot (key, title, socket type)."""

    # Front-end model identification, synced to the browser.
    _model_name = traitlets.Unicode("ReteInputModel").tag(sync=True)
    _model_module = traitlets.Unicode("jupyterlab_nodeeditor").tag(sync=True)
    _model_module_version = traitlets.Unicode(EXTENSION_VERSION).tag(sync=True)

    # Slot description, synced to the browser.
    key = traitlets.Unicode().tag(sync=True)
    title = traitlets.Unicode().tag(sync=True)
    socket_type = traitlets.Unicode().tag(sync=True)
    sockets = traitlets.Instance(SocketCollection).tag(
        sync=True,
        **ipywidgets.widget_serialization,
    )

    def _ipython_display_(self):
        """Display the label produced by :meth:`widget`."""
        rendered = self.widget()
        display(rendered)

    def widget(self):
        """Return a new ``Label`` summarising key, title and socket type."""
        summary = f"Slot {self.key}: {self.title} ({self.socket_type})"
        return ipywidgets.Label(summary)
Ejemplo n.º 3
0
class NodeInstanceModel(ipywidgets.Widget):
    """Widget model for one node instance with input and output slots."""

    # Front-end model / view identification, synced to the browser.
    _model_name = traitlets.Unicode("ReteNodeModel").tag(sync=True)
    _model_module = traitlets.Unicode("jupyterlab_nodeeditor").tag(sync=True)
    _model_module_version = traitlets.Unicode(EXTENSION_VERSION).tag(sync=True)
    _view_name = traitlets.Unicode("ReteNodeView").tag(sync=True)
    _view_module = traitlets.Unicode("jupyterlab_nodeeditor").tag(sync=True)
    _view_module_version = traitlets.Unicode(EXTENSION_VERSION).tag(sync=True)

    title = traitlets.Unicode("Title").tag(sync=True)
    # We distinguish between name and title because one is displayed on all
    # instances and the other is the name of the component type
    type_name = traitlets.Unicode(
        "DefaultComponent",
        allow_none=False,
    ).tag(sync=True)
    inputs = traitlets.List(InputSlotTrait()).tag(
        sync=True,
        **ipywidgets.widget_serialization,
    )
    outputs = traitlets.List(OutputSlotTrait()).tag(
        sync=True,
        **ipywidgets.widget_serialization,
    )
    display_element = traitlets.Instance(ipywidgets.VBox)

    @traitlets.default("display_element")
    def _default_display_element(self):
        """Build the default VBox: title label, inputs box and outputs box.

        The title label is linked to ``title``; the two boxes are refilled
        by observers whenever ``inputs`` / ``outputs`` change.
        """

        def _refresh_inputs(change):
            # Header label first, then one widget per current input slot.
            inputs_column.children = [
                ipywidgets.Label("Inputs"),
                *(slot.widget() for slot in self.inputs),
            ]

        def _refresh_outputs(change):
            # Header label first, then one widget per current output slot.
            outputs_column.children = [
                ipywidgets.Label("Outputs"),
                *(slot.widget() for slot in self.outputs),
            ]

        title_label = ipywidgets.Label()
        traitlets.link((self, "title"), (title_label, "value"))

        self.observe(_refresh_inputs, ["inputs"])
        self.observe(_refresh_outputs, ["outputs"])

        # The closures above look these names up lazily, when they fire.
        inputs_column = ipywidgets.VBox([ipywidgets.Label("Inputs")])
        outputs_column = ipywidgets.VBox([ipywidgets.Label("Outputs")])

        return ipywidgets.VBox([title_label, inputs_column, outputs_column])
Ejemplo n.º 4
0
class Base(WXYZBase):
    """Utility traitlets, primarily based around
    - development convenience
    - ipywidgets conventions
    - integration with wxyz.lab.DockBox, mostly lumino Widget.label attrs
    """

    # Model and view module name / version; all four default to the
    # module-level ``module_name`` / ``module_version`` and are tagged sync.
    _model_module = T.Unicode(module_name).tag(sync=True)
    _model_module_version = T.Unicode(module_version).tag(sync=True)
    _view_module = T.Unicode(module_name).tag(sync=True)
    _view_module_version = T.Unicode(module_version).tag(sync=True)

    # Casting unicode trait (CUnicode), empty string by default.
    error = T.CUnicode("").tag(sync=True)  # type: str
    # Description text; the default marks widgets that did not set one.
    description = T.Unicode("An Undescribed Widget").tag(
        sync=True)  # type: str
    # CSS class name for the icon (presumably used by the DockBox tab —
    # TODO confirm against the front end).
    icon_class = T.Unicode("jp-CircleIcon").tag(sync=True)  # type: str
    # Synced boolean flag, True by default.
    closable = T.Bool(default_value=True).tag(sync=True)  # type: bool
Ejemplo n.º 5
0
class LayoutAlgorithm(LayoutOptionWidget):
    """Select a specific layout algorithm.

    https://www.eclipse.org/elk/reference/options/org-eclipse-elk-algorithm.html
    """

    identifier = "org.eclipse.elk.algorithm"

    value = T.Enum(
        values=list(ALGORITHM_OPTIONS.keys()),
        default_value=ELKLayered.identifier,
    )
    metadata_provider = T.Unicode()
    applies_to = ["parents"]

    def _ui(self) -> List[W.Widget]:
        """Return a single dropdown linked to ``value``.

        Each dropdown entry pairs an algorithm class ``title`` (label) with
        its key in ``ALGORITHM_OPTIONS`` (value).
        """
        choices = []
        for algorithm_id, algorithm_cls in ALGORITHM_OPTIONS.items():
            choices.append((algorithm_cls.title, algorithm_id))

        selector = W.Dropdown(description="Layout Algorithm", options=choices)
        T.link((self, "value"), (selector, "value"))
        return [selector]

    @T.default("metadata_provider")
    def _default_metadata_provider(self):
        """Default value for the current metadata provider"""
        initial = self._update_metadata_provider()
        return initial

    @T.observe("value")
    def _update_metadata_provider(self, change: T.Bunch = None):
        """Change handler: copy the selected algorithm's metadata provider.

        Looks up the class registered under the current ``value``, stores
        its ``metadata_provider`` on this widget and returns it.
        """
        selected_cls = ALGORITHM_OPTIONS[self.value]
        provider = selected_cls.metadata_provider
        self.metadata_provider = provider
        return provider
Ejemplo n.º 6
0
class Tool(W.Widget):
    """Widget that runs an asynchronous ``run`` coroutine on demand."""

    tee: Pipe = T.Instance(Pipe, allow_none=True).tag(
        sync=True,
        **W.widget_serialization,
    )
    on_done = T.Any(allow_none=True)  # callback when done
    disable = T.Bool(default_value=False).tag(
        sync=True,
        **W.widget_serialization,
    )
    reports = TypedTuple(T.Unicode(), kw={})
    _task: asyncio.Future = None
    ui = T.Instance(W.DOMWidget, allow_none=True)
    priority = T.Int(default_value=10)

    def handler(self, *args):
        """Handler callback for running the tool"""
        # Cancel old work if needed.
        stale = self._task
        if stale:
            stale.cancel()

        # Schedule the new work and register the completion callback.
        self._task = scheduled = asyncio.create_task(self.run())
        scheduled.add_done_callback(self._finished)

        # Forward what this tool reports to the tee's inlet, when present.
        tee = self.tee
        if tee:
            tee.inlet.flow = self.reports

    async def run(self):
        """Do the tool's work; subclasses must override."""
        raise NotImplementedError()

    def _finished(self, future: asyncio.Future):
        """Done-callback: surface the task result and invoke ``on_done``.

        Cancellation is ignored silently; any other exception (including
        one raised by ``on_done``) is logged through ``self.log``.
        """
        try:
            future.result()
            callback = self.on_done
            if callable(callback):
                callback()
        except asyncio.CancelledError:
            # cancellation should not log an error
            return
        except Exception:
            self.log.exception(f"Error running tool: {type(self)}")
Ejemplo n.º 7
0
class Template(nowidget.Display):
    """Display object that renders ``body`` as a template into Markdown."""

    body = traitlets.Unicode()
    template = traitlets.Any()
    environment = traitlets.Any()
    globals = traitlets.Dict()

    def __init__(self, body: str, **kwargs):
        """Store ``body`` and register this display with the parent, if any.

        When the parent has no ``display_manager`` trait yet, the nowidget
        manager extension is loaded on it first.
        """
        super().__init__(body=body, **kwargs)
        shell = self.parent
        if not shell:
            return
        if not shell.has_trait('display_manager'):
            nowidget.manager.load_ipython_extension(shell)
        shell.display_manager.append(self)

    @traitlets.default('template')
    def _default_template(self):
        """Compile ``body`` with the current ``environment``."""
        env = self.environment
        return env.from_string(self.body)

    @traitlets.default('environment')
    def _default_environment(self):
        """Return the result of the module-level ``environment()`` call."""
        return environment()

    # NOTE(review): no ``vars`` trait is declared in this class; presumably
    # it is inherited from nowidget.Display — confirm.
    @traitlets.default('vars')
    def _default_vars(self):
        """Return the undeclared variable names found in ``body``."""
        parsed = self.template.environment.parse(self.body)
        return jinja2.meta.find_undeclared_variables(parsed)

    def render(self, **kwargs):
        """Render the template.

        The context is ``kwargs`` (or the parent's ``user_ns`` when no
        keyword arguments are given) overlaid with ``globals``.
        """
        renderer = self.template.render
        context = {**(kwargs or self.parent.user_ns), **self.globals}
        return renderer(context)

    def main(self, **kwargs):
        """Return the rendered template wrapped in a Markdown display."""
        markdown_cls = IPython.display.Markdown
        context = {**(kwargs or self.parent.user_ns), **self.globals}
        rendered = self.render(**context)
        return markdown_cls(rendered)
Ejemplo n.º 8
0
class Worker(traitlets.HasTraits):
    """A worker that processes one node at a time.

    Workers compare (``==`` and ``<``) by their remaining ``time_left``.
    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.
    """

    node = traitlets.Unicode(allow_none=True)
    time_left = traitlets.Integer(default_value=0)
    is_working = traitlets.Bool(default_value=False)

    def __eq__(self, other):
        """Return whether both objects have the same ``time_left``.

        Objects without a ``time_left`` attribute (e.g. ``None``) yield
        ``NotImplemented`` instead of raising ``AttributeError``, so that
        ``worker == None`` evaluates to ``False``.
        """
        if not hasattr(other, "time_left"):
            return NotImplemented
        return self.time_left == other.time_left

    def __lt__(self, other):
        """Return whether this worker has less time left than ``other``.

        Objects without a ``time_left`` attribute yield ``NotImplemented``,
        which lets Python raise its regular ``TypeError`` for unsupported
        comparisons.
        """
        if not hasattr(other, "time_left"):
            return NotImplemented
        return self.time_left < other.time_left

    def tick(self, time):
        """Advance the worker by ``time`` units.

        Once ``time_left`` reaches zero or below, the worker is marked idle,
        its node is recorded in ``completed_nodes`` and
        ``ordered_completed_nodes`` (names resolved outside this class), and
        ``node`` is cleared.
        """
        self.time_left -= time
        if self.time_left <= 0:
            self.is_working = False
            completed_nodes.add(self.node)
            ordered_completed_nodes.append(self.node)
            self.node = None

    def assign_task(self, node):
        """Start working on ``node``.

        The duration is read from ``node_costs[node]`` (a name resolved
        outside this class).
        """
        self.node = node
        self.time_left = node_costs[node]
        self.is_working = True
Ejemplo n.º 9
0
class MarkElementWidget(W.DOMWidget):
    """DOM widget pairing a root ``Node`` value with a ``MarkIndex``."""

    value: Node = T.Instance(Node, allow_none=True).tag(
        sync=True,
        **elk_serialization,
    )
    index: MarkIndex = T.Instance(MarkIndex, kw={}).tag(
        sync=True,
        **W.widget_serialization,
    )
    flow: Tuple[str] = TypedTuple(T.Unicode(), kw={}).tag(sync=True)

    def persist(self):
        """Bring the index's elements up to date with ``value``.

        Builds the index when it has no elements yet; otherwise updates the
        existing elements from an ``ElementIndex`` of the current value.
        Returns ``self`` so calls can be chained.
        """
        existing = self.index.elements
        if existing is None:
            self.build_index()
            return self
        existing.update(ElementIndex.from_els(self.value))
        return self

    def build_index(self) -> MarkIndex:
        """Replace the index's elements and return the ``MarkIndex``.

        An empty ``ElementIndex`` is used when ``value`` is ``None``;
        otherwise one is built from ``value`` inside the index's context.
        """
        root = self.value
        if root is not None:
            with self.index.context:
                fresh = ElementIndex.from_els(root)
        else:
            fresh = ElementIndex()
        self.index.elements = fresh
        return self.index
Ejemplo n.º 10
0
class TerrainTilesComposite(S3Mixin, TileCompositor):
    """Tile compositor with one ``TerrainTilesSource`` per entry of ``urls``."""

    urls = tl.List(trait=tl.Unicode()).tag(attr=True)
    anon = tl.Bool(True)

    _repr_keys = ["urls"]

    @cached_property
    def sources(self):
        """List of sources, one per url, in ``urls`` order (cached)."""
        return list(map(self._create_source, self.urls))

    def get_coordinates(self):
        """Return the union of the coordinates of all sources."""
        per_source = []
        for src in self.sources:
            per_source.append(src.coordinates)
        return podpac.coordinates.union(per_source)

    def _create_source(self, url):
        """Build a ``TerrainTilesSource`` for ``url``.

        Cache settings, ``force_eval`` and the ``s3`` handle are copied
        from this node; ``cache_dataset`` is always enabled.
        """
        options = {
            "source": url,
            "cache_ctrl": self.cache_ctrl,
            "force_eval": self.force_eval,
            "cache_output": self.cache_output,
            "cache_dataset": True,
            "s3": self.s3,
        }
        return TerrainTilesSource(**options)
Ejemplo n.º 11
0
class LayoutOptionWidget(W.VBox):
    """Base VBox for a single layout option.

    Subclasses set the class-level descriptors below and implement
    :meth:`_ui`; they may override :meth:`_update_value`.
    """

    identifier: Hashable = None
    metadata_provider: str = None
    applies_to: List[ElkGraphElement] = None
    group: str = None
    title: str = None  # optional title for UI purposes

    value = T.Unicode()

    def __init__(self, *args, **kwargs):
        """Initialise the VBox, then compute the initial ``value``."""
        super().__init__(*args, **kwargs)
        self._update_value()

    def _ipython_display_(self, **kwargs):
        """Create the controls on first display, then display as usual."""
        has_controls = bool(self.children)
        if not has_controls:
            self.children = self._ui()
        super()._ipython_display_(**kwargs)

    def _ui(self) -> List[W.Widget]:
        """Return the option's controls; must be provided by subclasses."""
        raise NotImplementedError(
            "Subclasses should implement their specific UI Controls")

    def _update_value(self):
        """Recompute ``value``; a no-op here, subclasses override it."""
        return None

    @classmethod
    def matches(cls, elk_type: Type[ElkGraphElement]):
        """Checks if this LayoutOption applies to given ElkGraphElement type"""
        targets = cls.applies_to
        if targets is None:
            return False
        if not isinstance(targets, (tuple, list)):
            return elk_type == targets
        # A fallback that let ElkNode also match options whose applies_to
        # contains "parents" is currently disabled.
        return elk_type in targets
Ejemplo n.º 12
0
class RTSPCamera(Camera):
    """Camera that reads frames from ``capture_device`` via OpenCV/FFMPEG."""

    capture_fps = traitlets.Integer(default_value=15)
    capture_width = traitlets.Integer(default_value=640)
    capture_height = traitlets.Integer(default_value=480)
    capture_device = traitlets.Unicode(
        default_value='rtsp://Your_IP_Cam_Address')

    def __init__(self, *args, **kwargs):
        """Open the capture device and verify one frame can be read.

        Raises
        ------
        RuntimeError
            If the device cannot be opened or the first read fails. The
            underlying error is attached as ``__cause__``.
        """
        super(RTSPCamera, self).__init__(*args, **kwargs)
        try:
            self.cap = cv2.VideoCapture(self.capture_device, cv2.CAP_FFMPEG)

            ok, _ = self.cap.read()

            if not ok:
                raise RuntimeError(
                    'Could not read image from camera, phase 1.')

        except Exception as err:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not converted into a
            # RuntimeError; the cause is chained explicitly.
            raise RuntimeError(
                'Could not initialize camera.  Please see error trace, phase 2.'
            ) from err

        # Release the capture handle at interpreter exit.
        atexit.register(self.cap.release)

    def _gst_str(self):
        """Return ``(capture_device, capture_width, capture_height)``.

        NOTE(review): despite the name this returns a tuple, not a pipeline
        string, and is not used by ``__init__`` in this class.
        """
        return (self.capture_device, self.capture_width, self.capture_height)

    def _read(self):
        """Read one frame and resize it to ``(self.width, self.height)``.

        Raises
        ------
        RuntimeError
            If the capture device returns no frame.
        """
        ok, image = self.cap.read()
        if not ok:
            raise RuntimeError('Could not read image from camera, phase 3')
        target_size = (int(self.width), int(self.height))
        return cv2.resize(image, target_size)
Ejemplo n.º 13
0
class YearSubstituteCoordinates(ModifyCoordinates):
    """Coordinate modifier that replaces the year of time coordinates."""

    year = tl.Unicode().tag(attr=True)

    # Remove tags from attributes
    lat = tl.List()
    lon = tl.List()
    time = tl.List()
    alt = tl.List()
    coordinates_source = None

    def get_modified_coordinates1d(self, coord, dim):
        """
        Get the 1d coordinates for the given dimension, substituting the
        year for the time dimension.

        Parameters
        ----------
        coord : Coordinates
            The requested input coordinates
        dim : str
            Dimension to fetch; only "time" is modified.

        Returns
        -------
        coords1d : ArrayCoordinates1d
            ``coord[dim]`` unchanged for non-time dimensions. For "time", a
            new array in which every coordinate has been passed through
            ``add_coord`` together with the difference between ``self.year``
            and that coordinate truncated to year resolution.
        """
        if dim == "time":
            original_times = coord["time"]
            target_year = np.datetime64(self.year)

            shifted = []
            for stamp in original_times.coordinates:
                year_offset = target_year - stamp.astype("datetime64[Y]")
                shifted.append(add_coord(stamp, year_offset))

            return ArrayCoordinates1d(shifted, name="time")

        return coord[dim]
Ejemplo n.º 14
0
class SortField(BaseObject):
    """Wrapper for Vega-Lite SortField definition.

    Attributes
    ----------
    field: Unicode
        The field name to aggregate over.
    op: AggregateOp
        The sort aggregation operator.
    order: SortOrder
        Optional ``SortOrder`` value; ``None`` by default.
    """
    field = T.Unicode(
        allow_none=True,
        default_value=None,
        help="""The field name to aggregate over.""",
    )
    op = AggregateOp(
        allow_none=True,
        default_value=None,
        help="""The sort aggregation operator.""",
    )
    order = SortOrder(allow_none=True, default_value=None)

    def __init__(self, field=None, op=None, order=None, **kwargs):
        """Forward the explicitly given (non-``None``) values to the base.

        Named arguments that are not ``None`` override same-named entries
        of ``kwargs``.
        """
        explicit = (("field", field), ("op", op), ("order", order))
        for trait_name, given in explicit:
            if given is not None:
                kwargs[trait_name] = given
        super(SortField, self).__init__(**kwargs)
Ejemplo n.º 15
0
class CssColor(Color):
    """Colour given by a CSS colour string, with optional alpha in [0, 1]."""

    name = traitlets.Unicode()
    alpha = traitlets.Float(min=0., max=1., allow_none=True)

    def __init__(self, name, alpha=None):
        """Store the CSS colour ``name`` and the optional ``alpha``."""
        self.name = name
        self.alpha = alpha

    def __repr__(self):
        """Return the ``Color.fromCssColorString(...)`` expression.

        ``.withAlpha(...)`` is appended only when ``alpha`` is set.
        """
        if self.alpha is not None:
            template = """Color.fromCssColorString("{name}").withAlpha({alpha})"""
            return template.format(name=self.name, alpha=self.alpha)

        template = """Color.fromCssColorString("{name}")"""
        return template.format(name=self.name)

    @property
    def script(self):
        """The repr prefixed with ``Cesium.`` (no ``new`` is emitted)."""
        expression = repr(self)
        return 'Cesium.{rep}'.format(rep=expression)

    def copy(self):
        """Return a new instance of the same class with equal name/alpha."""
        klass = self.__class__
        return klass(name=self.name, alpha=self.alpha)
Ejemplo n.º 16
0
class PortAnchorOffset(LayoutOptionWidget):
    """The offset to the port position where connections shall be attached.

    https://www.eclipse.org/elk/reference/options/org-eclipse-elk-port-anchor.html
    """

    identifier = "org.eclipse.elk.port.anchor"
    metadata_provider = "core.options.CoreOptions"
    applies_to = [ElkPort]
    group = "port"

    x = T.Int(default_value=0)
    y = T.Int(default_value=0)
    value = T.Unicode(allow_none=True)
    active = T.Bool(default_value=False)

    def _ui(self) -> List[W.Widget]:
        """Return the controls: an "Active" checkbox and two int sliders.

        The checkbox is linked to ``active`` and the sliders to ``x`` and
        ``y`` respectively.
        """
        cb = W.Checkbox(description="Active")
        x_slider = W.IntSlider(description="Width")
        y_slider = W.IntSlider(description="Height")

        T.link((self, "active"), (cb, "value"))
        T.link((self, "x"), (x_slider, "value"))
        T.link((self, "y"), (y_slider, "value"))
        return [
            cb,
            x_slider,
            y_slider,
        ]

    # ``active`` is observed as well: the value depends on it, so toggling
    # the checkbox must refresh ``value`` without waiting for x/y to change.
    @T.observe("active", "x", "y")
    def _update_value(self, change=None):
        """Recompute ``value`` from ``active``, ``x`` and ``y``.

        ``value`` is ``"(x, y)"`` while the option is active and ``None``
        otherwise.
        """
        if self.active:
            self.value = f"({self.x}, {self.y})"
        else:
            self.value = None
Ejemplo n.º 17
0
class RTSPCamera(Camera):
    """Camera that reads an RTSP stream through an OpenCV GStreamer pipeline."""

    capture_fps = traitlets.Integer(default_value=30)
    capture_width = traitlets.Integer(default_value=640)
    capture_height = traitlets.Integer(default_value=480)
    capture_device = traitlets.Unicode(
        default_value='rtsp://Your_IP_Cam_Address')

    def __init__(self, *args, **kwargs):
        """Open the GStreamer pipeline and verify one frame can be read.

        Raises
        ------
        RuntimeError
            If the pipeline cannot be opened or the first read fails. The
            underlying error is attached as ``__cause__``.
        """
        super(RTSPCamera, self).__init__(*args, **kwargs)
        try:
            self.cap = cv2.VideoCapture(self._gst_str(), cv2.CAP_GSTREAMER)

            ok, _ = self.cap.read()

            if not ok:
                raise RuntimeError('Could not read image from camera.')

        except Exception as err:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not converted into a
            # RuntimeError; the cause is chained explicitly.
            raise RuntimeError(
                'Could not initialize camera.  Please see error trace.'
            ) from err

        # Release the capture handle at interpreter exit.
        atexit.register(self.cap.release)

    def _gst_str(self):
        """Return the GStreamer pipeline string for the capture device.

        The device location, capture width and capture height are
        substituted into the pipeline template, in that order.
        """
        return 'rtspsrc location={} ! decodebin ! nvvidconv ! video/x-raw, width=(int){}, height=(int){}, format=(string)BGRx ! videoconvert ! appsink'.format(
            self.capture_device, self.capture_width, self.capture_height)

    def _read(self):
        """Read one frame and resize it to ``(self.width, self.height)``.

        Raises
        ------
        RuntimeError
            If the capture device returns no frame.
        """
        ok, image = self.cap.read()
        if not ok:
            raise RuntimeError('Could not read image from camera')
        target_size = (int(self.width), int(self.height))
        return cv2.resize(image, target_size)
Ejemplo n.º 18
0
class ElkJS(SyncedPipe):
    """Jupyterlab widget for calling `elkjs <https://github.com/kieler/elkjs>`_
    layout given a valid elkjson dictionary"""

    _model_name = T.Unicode("ELKLayoutModel").tag(sync=True)
    _model_module = T.Unicode(EXTENSION_NAME).tag(sync=True)
    _model_module_version = T.Unicode(EXTENSION_SPEC_VERSION).tag(sync=True)
    _view_module = T.Unicode(EXTENSION_NAME).tag(sync=True)

    observes = TypedTuple(T.Unicode(), default_value=(F.Anythinglayout,))
    reports = TypedTuple(T.Unicode(), default_value=(F.Layout,))

    async def run(self):
        """Send a ``run`` action and wait for the outlet's value to change.

        Does nothing when there is no outlet. After the change has been
        awaited, the outlet is persisted.
        """
        outlet = self.outlet
        if outlet is None:
            return

        # Start watching (once) before signalling, then send the action.
        pending_change = wait_for_change(outlet, "value")
        message = {"action": "run"}
        self.send(message)

        # Do not return until the awaited change has arrived.
        await pending_change
        self.outlet.persist()
Ejemplo n.º 19
0
class QueryConstructor(W.HBox):
    """HBox holding a ``QueryInput`` and a textarea with the formatted query.

    TODO
    - way better templating and more efficient formatting
    - replace individual observers with larger observer
    - move build_query to standalone function
    """

    convert_arrow = T.Instance(W.Image)
    query_input = T.Instance(W.VBox)
    formatted_query = T.Instance(W.Textarea)

    # traits from children
    namespaces = T.Unicode()
    query_type = T.Unicode(default_value="SELECT")
    query_line = T.Unicode(allow_none=True)
    query_body = T.Unicode()

    # Class-level output widget; also used as the capture target of
    # ``build_query`` below.
    log = W.Output()

    def __init__(self, *args, **kwargs):
        """Create the query input, link its traits here, and set children."""
        super().__init__(*args, **kwargs)
        self.query_input = QueryInput()

        # Inherit traits with links TODO easier way?
        T.link((self.query_input.namespaces, "namespaces"), (self, "namespaces"))
        T.link((self.query_input.header, "dropdown_value"), (self, "query_type"))
        T.link((self.query_input.header, "header_value"), (self, "query_line"))
        T.link((self.query_input.body.body, "value"), (self, "query_body"))

        self.children = (self.query_input, self.formatted_query)

    @log.capture()
    def build_query(self):
        """Return the query text built from the current trait values.

        The header depends on ``query_type``:

        - ``SELECT`` / ``SELECT DISTINCT``: type followed by ``query_line``
          (``*`` when the line is empty or ``None``).
        - ``ASK``: the type alone.
        - ``CONSTRUCT``: type followed by ``query_line``
          (``{?s ?p ?o}`` when the line is empty or ``None``).

        Raises
        ------
        ValueError
            For any other ``query_type``.
        """
        # get values TODO improve
        namespaces = self.namespaces
        query_type = self.query_type
        # ``query_line`` allows None; treat it like an empty line so the
        # header never contains the literal text "None".
        query_line = self.query_line or ""
        query_body = self.query_body or self.query_input.body.body.placeholder

        # update query_body
        query_body = "\t\n".join(
            query_body.split("\n")
        )  # TODO this isn't actually formatting properly

        header_str = ""
        # TODO move these to module vars
        if query_type in {"SELECT", "SELECT DISTINCT"}:
            if query_line == "":
                query_line = "*"
            header_str = f"{query_type} {query_line}"
        elif query_type == "ASK":
            header_str = query_type
        elif query_type == "CONSTRUCT":
            if query_line == "":
                query_line = "{?s ?p ?o}"
            header_str = f"{query_type} {query_line}"
        else:
            with self.log:
                raise ValueError(f"Unexpected query type: {query_type}")

        return query_template.format(
            namespaces,
            header_str,
            query_body,
        )

    @T.default("formatted_query")
    def make_default_formatted_query(self):
        """Return the default (empty) textarea for the formatted query."""
        formatted_query = W.Textarea(
            placeholder="Formatted query will appear here...",
            layout=W.Layout(height="260px", width="50%"),
        )
        return formatted_query

    @T.observe(
        "namespaces",
        "query_type",
        "query_line",
        "query_body",
    )
    def update_query(self, change):
        """Refresh the textarea whenever one of the query traits changes."""
        self.formatted_query.value = self.build_query()
Ejemplo n.º 20
0
class Widget(DOMWidget):
    """Jupyter DOM widget showing a REBOUND simulation.

    All traits below are tagged ``sync=True``. ``refresh`` copies particle
    (and optionally orbit) data from the simulation's display data into the
    ``particle_data`` / ``orbit_data`` traits.
    """
    _view_name = traitlets.Unicode('ReboundView').tag(sync=True)
    _view_module = traitlets.Unicode('rebound').tag(sync=True)
    # Incremented at the end of __init__ and of every refresh().
    count = traitlets.Int(0).tag(sync=True)
    # Reset to 0 and then incremented by takeScreenshot().
    screenshotcount = traitlets.Int(0).tag(sync=True)
    t = traitlets.Float().tag(sync=True)
    N = traitlets.Int().tag(sync=True)
    overlay = traitlets.Unicode('REB WIdund').tag(sync=True)
    width = traitlets.Float().tag(sync=True)
    height = traitlets.Float().tag(sync=True)
    scale = traitlets.Float().tag(sync=True)
    particle_data = traitlets.CBytes(allow_none=True).tag(sync=True)
    orbit_data = traitlets.CBytes(allow_none=True).tag(sync=True)
    orientation = traitlets.Tuple().tag(sync=True)
    orbits = traitlets.Int().tag(sync=True)
    screenshot = traitlets.Unicode().tag(sync=True)

    def __init__(self,
                 simulation,
                 size=(200, 200),
                 orientation=(0., 0., 0., 1.),
                 scale=None,
                 autorefresh=True,
                 orbits=True,
                 overlay=True):
        """
        Initializes a Widget.

        Widgets provide real-time 3D interactive visualizations for REBOUND simulations
        within Jupyter Notebooks. To use widgets, the ipywidgets package needs to be installed
        and enabled in your Jupyter notebook server.

        Parameters
        ----------
        simulation : object
            The simulation to display (presumably a rebound.Simulation). Its
            ``t``, ``N`` and ``display_data`` attributes are read here and a
            ctypes pointer to it is stored.
        size : (int, int), optional
            Specify the size of the widget in pixels. The default is 200 times 200 pixels.
        orientation : (float, float, float, float), optional
            Specify the initial orientation of the view. The four floats correspond to the
            x, y, z, and w components of a quaternion. The quaternion will be normalized.
        scale : float, optional
            Set the initial scale of the view. If not set, the scale is taken from the
            simulation's display data.
        autorefresh : bool, optional
            The default value is True. The view is updated whenever a particle is added,
            removed and every 100th of a second while a simulation is running. If set
            to False, then the user needs to manually call the refresh() function on the
            widget. This might be useful if performance is an issue.
        orbits : bool, optional
            The default value for this is True and the widget will draw the instantaneous
            orbits of the particles. For simulations in which particles are not on
            Keplerian orbits, the orbits shown will not be accurate.
        overlay : bool, string or None, optional
            Change the default text overlay. True (the default) shows the standard
            text, a string replaces its prefix, and None or False hides all text.
        """
        # Note: the trait values below are assigned before the base class
        # initializer is called at the end of this method.
        self.screenshotcountall = 0
        self.width, self.height = size
        self.t, self.N = simulation.t, simulation.N
        self.orientation = orientation
        self.autorefresh = autorefresh
        self.orbits = orbits
        self.useroverlay = overlay
        self.simp = pointer(simulation)
        clibrebound.reb_display_copy_data.restype = c_int
        if scale is None:
            self.scale = simulation.display_data.contents.scale
        else:
            self.scale = scale
        self.count += 1

        super(Widget, self).__init__()

    def refresh(self, simp=None, isauto=0):
        """
        Manually refreshes a widget.

        Note that this function can also be called using the wrapper function of
        the Simulation object: sim.refreshWidgets().

        Parameters
        ----------
        simp : ctypes pointer, optional
            Pointer to the simulation to read from. Defaults to the pointer
            stored at construction time.
        isauto : int, optional
            When 1 and autorefresh is disabled, the call returns without
            doing anything.
        """

        if simp is None:
            simp = self.simp
        if self.autorefresh == 0 and isauto == 1:
            return
        sim = simp.contents
        size_changed = clibrebound.reb_display_copy_data(simp)
        clibrebound.reb_display_prepare_data(simp, c_int(self.orbits))
        if sim.N > 0:
            # Raw copy of 4 * 7 * N bytes starting at the particle buffer.
            self.particle_data = (c_char * (4 * 7 * sim.N)).from_address(
                sim.display_data.contents.particle_data).raw
            if self.orbits:
                # Raw copy of 4 * 9 * (N - 1) bytes starting at the orbit buffer.
                self.orbit_data = (
                    c_char * (4 * 9 * (sim.N - 1))).from_address(
                        sim.display_data.contents.orbit_data).raw
        if size_changed:
            #TODO: Implement better GPU size change
            pass
        # Overlay text: default text for True, empty for None/False,
        # otherwise the user's string followed by N and t.
        if self.useroverlay == True:
            self.overlay = "REBOUND (%s), N=%d, t=%g" % (sim.integrator, sim.N,
                                                         sim.t)
        elif self.useroverlay is None or self.useroverlay == False:
            self.overlay = ""
        else:
            self.overlay = self.useroverlay + ", N=%d, t=%g" % (sim.N, sim.t)
        self.N = sim.N
        self.t = sim.t
        # presumably signals the front-end view to redraw — confirm against
        # the JS view
        self.count += 1

    def takeScreenshot(self,
                       times=None,
                       prefix="./screenshot",
                       resetCounter=False,
                       archive=None,
                       mode="snapshot"):
        """
        Take one or more screenshots of the widget and save the images to a file.
        The images can be used to create a video.

        This function cannot be called multiple times within one cell.

        Note: this is a new feature and might not work on all systems.
        It was tested on python 2.7.10 and 3.5.2 on MacOSX.

        Parameters
        ----------
        times : (float, list), optional
            If this argument is not given a screenshot of the widget will be made
            as it is (without integrating the simulation). If a float is given, then the
            simulation will be integrated to that time and then a screenshot will
            be taken. If a list of floats is given, the simulation will be integrated
            to each time specified in the array. A separate screenshot for
            each time will be saved.
        prefix : (str), optional
            This string will be part of the output filename for each image.
            Followed by a five digit integer and the suffix .png. By default the
            prefix is './screenshot' which outputs images in the current
            directory with the filenames screenshot00000.png, screenshot00001.png...
            Note that the prefix can include a directory.
        resetCounter : (bool), optional
            Resets the output counter to 0.
        archive : (rebound.SimulationArchive), optional
            Use a REBOUND SimulationArchive. Thus, instead of integrating the
            Simulation from the current time, it will use the SimulationArchive
            to load a snapshot. See examples for usage.
        mode : (string), optional
            Mode to use when querying the SimulationArchive. See SimulationArchive
            documentation for details. By default the value is "snapshot".

        Raises
        ------
        ValueError
            In archive mode, when ``times`` is missing or is not a sequence.

        Examples
        --------

        First, create a simulation and widget. All of the following can go in
        one cell.

        >>> sim = rebound.Simulation()
        >>> sim.add(m=1.)
        >>> sim.add(m=1.e-3,x=1.,vy=1.)
        >>> w = sim.getWidget()
        >>> w

        The widget should show up. To take a screenshot, simply call

        >>> w.takeScreenshot()

        A new file with the name screenshot00000.png will appear in the
        current directory.

        Note that the takeScreenshot command needs to be in a separate cell,
        i.e. after you see the widget.

        You can pass an array of times to the function. This allows you to
        take multiple screenshots, for example to create a movie,

        >>> times = [0,10,100]
        >>> w.takeScreenshot(times)

        """
        self.archive = archive
        if resetCounter:
            self.screenshotcountall = 0
        self.screenshotprefix = prefix
        self.screenshotcount = 0
        self.overlay = "REBOUND"
        self.screenshot = ""
        if archive is None:
            if times is None:
                times = self.simp.contents.t
            # NOTE(review): bare ``except`` also catches KeyboardInterrupt;
            # ``except TypeError`` would be the narrower match for len().
            try:
                # List
                len(times)
            except:
                # Float:
                times = [times]
            self.times = times
            self.observe(savescreenshot, names="screenshot")
            self.simp.contents.integrate(times[0])
            self.screenshotcount += 1  # triggers first screenshot
        else:
            if times is None:
                raise ValueError("Need times argument for archive mode.")
            # NOTE(review): bare ``except`` here as well; see above.
            try:
                len(times)
            except:
                raise ValueError("Need a list of times for archive mode.")
            self.times = times
            self.mode = mode
            self.observe(savescreenshot, names="screenshot")
            sim = archive.getSimulation(times[0], mode=mode)
            self.refresh(pointer(sim))
            self.screenshotcount += 1  # triggers first screenshot

    @staticmethod
    def getClientCode():
        """Return ``shader_code`` and ``js_code`` concatenated."""
        return shader_code + js_code
Ejemplo n.º 21
0
class DropdownButton(widgets.VBox):
    """A dropdown and a submit button on one row, with a hint output below.

    ``options`` holds the (de-duplicated) choices and is pushed one-way into
    the dropdown. ``submission_functions`` collects the callbacks that are
    invoked when the button is clicked.
    """

    options = traitlets.List(
        trait=traitlets.Unicode(), default_value=list(), allow_none=True
    )
    submission_functions = traitlets.List(default_value=list(), allow_none=True)

    def __init__(self, options: Sequence[str], *args, **kwargs):
        """Create a dropdown button.

        Parameters
        ----------
        options : Sequence[str]
            The options to display in the widget.
        """
        super().__init__(*args, **kwargs)

        self.options = options

        choices = [str(option) for option in self.options]
        self.dropdown = widgets.Dropdown(options=choices, description="Label:")
        # One-directional link: changes to self.options flow into the dropdown.
        widgets.dlink((self, "options"), (self.dropdown, "options"))
        # No trait names given, so the handler fires on any dropdown change.
        self.dropdown.observe(self._change_selection)

        self.button = widgets.Button(
            description="Submit.", tooltip="Submit label.", button_style="success"
        )
        self.button.on_click(self._handle_click)

        # One Output widget per dropdown value, created lazily on first access.
        self.hints: DefaultDict[str, widgets.Output] = defaultdict(widgets.Output)

        self.children = self._compose_children()

    def _compose_children(self):
        """Build the child list: the control row plus the current value's hint."""
        control_row = widgets.HBox([self.dropdown, self.button])
        return [control_row, self.hints[self.dropdown.value]]

    def on_click(self, func: Callable) -> None:
        """Register a function to be called after the button is clicked.

        Parameters
        ----------
        func : Callable
            The function to call when the button is clicked.

        Raises
        ------
        ValueError
            If ``func`` is not callable.
        """
        if callable(func):
            self.submission_functions.append(func)
            return
        raise ValueError(
            f"You need to provide a callable object, but you provided {func!s}."
        )

    def _handle_click(self, owner: widgets.Button) -> None:
        """Call every registered submission function with the clicked button."""
        for callback in self.submission_functions:
            callback(owner)

    def _change_selection(self, change=None):
        """Sync the button text/state and the hint area with the dropdown value."""
        selected = self.dropdown.value
        if selected is None:
            self.button.description = "Submit."
            self.button.disabled = True
        else:
            self.button.description = selected
            self.button.disabled = False

        self.children = self._compose_children()

    @traitlets.validate("options")
    def _check_options(self, proposal):
        """Drop duplicate options, keeping the first occurrence of each."""
        already_seen = set()
        unique = []
        for option in proposal["value"]:
            if option not in already_seen:
                already_seen.add(option)
                unique.append(option)
        return unique
Ejemplo n.º 22
0
class _BqplotMixin(traitlets.HasTraits):
    """Mixin that builds a bqplot figure with linked limits, labels and tools.

    The axis limits and labels declared as traits are linked in both
    directions to the bqplot scales and axes. Three interactions are
    prepared (pan-zoom, rectangle brush, x-interval brush) and the ``tool``
    trait selects which one is active on the figure.

    The callbacks use ``self.output`` (as a context manager) and
    ``self.presenter``; neither is created here, so the host class has to
    provide them.
    """

    x_min = traitlets.CFloat()
    x_max = traitlets.CFloat()
    y_min = traitlets.CFloat(None, allow_none=True)
    y_max = traitlets.CFloat(None, allow_none=True)
    x_label = traitlets.Unicode()
    y_label = traitlets.Unicode()
    tool = traitlets.Unicode(None, allow_none=True)

    def __init__(self, zoom_y=True, **kwargs):
        """Create scales, axes, figure and interactions.

        Parameters
        ----------
        zoom_y : bool
            When false, the pan-zoom interaction gets no y scale.
        **kwargs
            Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)

        # Scales, with the limit traits linked to their min/max.
        self.x_scale = bqplot.LinearScale(allow_padding=False)
        self.y_scale = bqplot.LinearScale(allow_padding=False)
        limit_links = (
            ('x_min', self.x_scale, 'min'),
            ('x_max', self.x_scale, 'max'),
            ('y_min', self.y_scale, 'min'),
            ('y_max', self.y_scale, 'max'),
        )
        for trait_name, scale, bound in limit_links:
            widgets.link((self, trait_name), (scale, bound))

        # Axes, with the label traits linked to the axis labels.
        self.x_axis = bqplot.Axis(scale=self.x_scale)
        self.y_axis = bqplot.Axis(scale=self.y_scale, orientation='vertical')
        widgets.link((self, 'x_label'), (self.x_axis, 'label'))
        widgets.link((self, 'y_label'), (self.y_axis, 'label'))

        # Axis styling; assignment order is the same as it has always been.
        for axis in (self.x_axis, self.y_axis):
            axis.color = blackish
        for axis in (self.x_axis, self.y_axis):
            axis.label_color = blackish
        for axis in (self.y_axis, self.x_axis):
            axis.grid_color = blackish
        self.x_axis.label_offset = "2em"
        self.y_axis.label_offset = "3em"
        for axis in (self.x_axis, self.y_axis):
            axis.grid_lines = 'none'

        self.axes = [self.x_axis, self.y_axis]
        self.scales = {'x': self.x_scale, 'y': self.y_scale}

        self.figure = bqplot.Figure(axes=self.axes)
        self.figure.background_style = {'fill': 'none'}
        self.figure.padding_y = 0
        self.figure.fig_margin = {
            'bottom': 40,
            'left': 60,
            'right': 10,
            'top': 10,
        }

        # Available interactions, keyed by the values the ``tool`` trait takes.
        y_zoom_scales = [self.y_scale] if zoom_y else []
        pan_zoom = bqplot.PanZoom(
            scales={'x': [self.x_scale], 'y': y_zoom_scales})
        rect_brush = bqplot.interacts.BrushSelector(
            x_scale=self.x_scale, y_scale=self.y_scale, color="green")
        interval_brush = bqplot.interacts.BrushIntervalSelector(
            scale=self.x_scale, color="green")
        self.interacts = {
            'pan-zoom': pan_zoom,
            'select-rect': rect_brush,
            'select-x': interval_brush,
        }
        self._brush = rect_brush
        self._brush_interval = interval_brush

        # TODO: put the debounce in the presenter?
        @vaex.jupyter.debounced(DEBOUNCE_SELECT)
        def on_rect_brush(*args):
            """Forward a rectangle selection (or its absence) to the presenter."""
            with self.output:
                # Once brushing has ended, detach the interaction ...
                if not self._brush.brushing:
                    self.figure.interaction = None
                if self._brush.selected is None:
                    self.presenter.select_nothing()
                else:
                    x_low, x_high = self._brush.selected_x
                    y_low, y_high = self._brush.selected_y
                    self.presenter.select_rectangle(x_low, x_high, y_low, y_high)
                # ... and attach it again so the drawn rectangle is gone.
                if not self._brush.brushing:
                    self.figure.interaction = self._brush

        self._brush.observe(on_rect_brush, ["selected", "selected_x"])

        @vaex.jupyter.debounced(DEBOUNCE_SELECT)
        def on_interval_brush(*args):
            """Forward an x-interval selection (or its absence) to the presenter."""
            with self.output:
                # Once brushing has ended, detach the interaction ...
                if not self._brush_interval.brushing:
                    self.figure.interaction = None
                interval = self._brush_interval.selected
                if interval is None or not len(interval):
                    self.presenter.select_nothing()
                else:
                    x_low, x_high = interval
                    self.presenter.select_x_range(x_low, x_high)
                # ... and attach it again so the drawn interval is gone.
                if not self._brush_interval.brushing:
                    self.figure.interaction = self._brush_interval

        self._brush_interval.observe(on_interval_brush, ["selected"])

        def on_tool_change(change=None):
            """Activate the interaction named by ``tool`` (none if unknown)."""
            self.figure.interaction = self.interacts.get(self.tool, None)

        self.observe(on_tool_change, 'tool')
        self.widget = self.figure
Ejemplo n.º 23
0
class Parameters(traitlets.HasTraits):
    """The physical and computational parameters are built on top of `traitlets`_.

    It is a framework that lets Python classes have attributes with type checking,
    dynamically calculated default values, and ‘on change’ callbacks.
    In addition, there are `ipywidgets`_ for a friendly user interface.

    .. warning:: There is a bug reported affecting the widgets `#2`_,
        they are not working properly at the moment.

    .. _traitlets:
        https://traitlets.readthedocs.io/en/stable/index.html
    .. _ipywidgets:
        https://ipywidgets.readthedocs.io/en/latest/
    .. _#2:
        https://github.com/fschuch/xcompact3d_toolbox/issues/2

    Attributes
    ----------
    nclx1 : int
        Boundary condition for velocity field where :math:`x=0`, the options are:

        * 0 - Periodic;
        * 1 - Free-slip;
        * 2 - Inflow.

    nclxn : int
        Boundary condition for velocity field where :math:`x=L_x`, the options are:

        * 0 - Periodic;
        * 1 - Free-slip;
        * 2 - Convective outflow.

    ncly1 : int
        Boundary condition for velocity field where :math:`y=0`, the options are:

        * 0 - Periodic;
        * 1 - Free-slip;
        * 2 - No-slip.

    nclyn : int
        Boundary condition for velocity field where :math:`y=L_y`, the options are:

        * 0 - Periodic;
        * 1 - Free-slip;
        * 2 - No-slip.

    nclz1 : int
        Boundary condition for velocity field where :math:`z=0`, the options are:

        * 0 - Periodic;
        * 1 - Free-slip;
        * 2 - No-slip.

    nclzn : int
        Boundary condition for velocity field where :math:`z=L_z`, the options are:

        * 0 - Periodic;
        * 1 - Free-slip;
        * 2 - No-slip.

    nclxS1 : int
        Boundary condition for scalar field(s) where :math:`x=0`, the options are:

        * 0 - Periodic;
        * 1 - No-flux;
        * 2 - Inflow.

    nclxSn : int
        Boundary condition for scalar field(s) where :math:`x=L_x`, the options are:

        * 0 - Periodic;
        * 1 - No-flux;
        * 2 - Convective outflow.

    nclyS1 : int
        Boundary condition for scalar field(s) where :math:`y=0`, the options are:

        * 0 - Periodic;
        * 1 - No-flux;
        * 2 - Dirichlet.

    nclySn : int
        Boundary condition for scalar field(s) where :math:`y=L_y`, the options are:

        * 0 - Periodic;
        * 1 - No-flux;
        * 2 - Dirichlet.

    nclzS1 : int
        Boundary condition for scalar field(s) where :math:`z=0`, the options are:

        * 0 - Periodic;
        * 1 - No-flux;
        * 2 - Dirichlet.

    nclzSn : int
        Boundary condition for scalar field(s) where :math:`z=L_z`, the options are:

        * 0 - Periodic;
        * 1 - No-flux;
        * 2 - Dirichlet.

    ivisu : bool
        Enables storing snapshots if :obj:`True`.

    ipost : bool
        Enables online postprocessing if :obj:`True`.

    ilesmod : bool
        Enables Large-Eddy methodologies if :obj:`True`.

    ifirst : int
        The number for the first iteration.

    ilast : int
        The number for the last iteration.

    icheckpoint : int
        Frequency for writing restart file.

    ioutput : int
        Frequency for visualization (3D snapshots).

    iprocessing : int
        Frequency for online postprocessing.

        Notes
        -----
            The exact output may differ depending on the flow configuration.
    """

    #
    # # BasicParam
    #

    p_row, p_col = (
        traitlets.Int(default_value=0, min=0).tag(
            group="BasicParam",
            widget=widgets.Dropdown(description=name, options=[0]),
        )
        for name in ("p_row", "p_col")
    )
    """int: Domain decomposition for (large-scale) parallel computation.

    Notes
    -----
        The product ``p_row * p_col`` must be equal to the number of
        computational cores where Xcompact3d will run.
        More information can be found at `2DECOMP&FFT`_.

        ``p_row = p_col = 0`` activates auto-tuning.

    .. _2DECOMP&FFT:
        http://www.2decomp.org
    """

    itype = traitlets.Int(
        default_value=10,
        min=0,
        max=10,
    ).tag(
        group="BasicParam",
        widget=widgets.Dropdown(
            description="itype",
            disabled=True,
            options=[
                ("User", 0),
                ("Lock-exchange", 1),
                ("Taylor-Green Vortex", 2),
                ("Channel", 3),
                ("Periodic Hill", 4),
                ("Cylinder", 5),
                ("Debug Schemes", 6),
                ("Mixing Layer", 7),
                ("Turbulent Jet", 8),
                ("Turbulent Boundary Layer", 9),
                ("Sandbox", 10),
            ],
        ),
    )
    """int: Selects the flow configuration. Each one is specified in its own
    ``BC.<flow-configuration>.f90`` file (see `Xcompact3d/src`_):

    * 0 - User configuration;
    * 1 - Turbidity Current in Lock-Release;
    * 2 - Taylor-Green Vortex;
    * 3 - Periodic Turbulent Channel;
    * 4 - Periodic Hill;
    * 5 - Flow around a Cylinder;
    * 6 - Debug Schemes (for developers);
    * 7 - Mixing Layer;
    * 8 - Turbulent Jet;
    * 9 - Turbulent Boundary Layer;
    * 10 - `Sandbox`_ (default).

    .. _Xcompact3d/src:
        https://github.com/fschuch/Xcompact3d/tree/master/src
    """

    iin = traitlets.Int(
        default_value=0,
        min=0,
        max=2,
    ).tag(
        group="BasicParam",
        widget=widgets.Dropdown(
            description="iin",
            options=[
                ("No random noise", 0),
                ("Random noise", 1),
                ("Random noise with fixed seed", 2),
            ],
        ),
    )
    """int: Perturbation applied to the initial condition:

    * 0 - No random noise (default);
    * 1 - Random noise with amplitude of :obj:`init_noise`;
    * 2 - Random noise with fixed seed
      (important for reproducibility, development and debugging)
      and amplitude of :obj:`init_noise`.

    Notes
    -----
        The exact behavior may differ depending on the flow configuration.
    """

    nx, ny, nz = (
        traitlets.Int(default_value=17, min=0).tag(
            group="BasicParam",
            widget=widgets.Dropdown(description=name, options=possible_mesh),
        )
        for name in ("nx", "ny", "nz")
    )
    """int: Number of mesh points.

    Notes
    -----
        See :obj:`possible_mesh` and :obj:`possible_mesh_p`.
    """

    xlx, yly, zlz = (
        traitlets.Float(default_value=1.0, min=0).tag(
            group="BasicParam",
            widget=widgets.BoundedFloatText(
                description=name,
                min=0.0,
                max=1e6,
            ),
        )
        for name in ("xlx", "yly", "zlz")
    )
    """float: Domain size.
    """

    # Documented in the class docstring
    nclx1 = traitlets.Int(
        default_value=2,
        min=0,
        max=2,
    ).tag(
        group="BasicParam",
        widget=widgets.Dropdown(
            description="nclx1",
            options=[
                ("Periodic", 0),
                ("Free-slip", 1),
                ("Inflow", 2),
            ],
        ),
    )

    # Documented in the class docstring
    nclxn = traitlets.Int(
        default_value=2,
        min=0,
        max=2,
    ).tag(
        group="BasicParam",
        widget=widgets.Dropdown(
            description="nclxn",
            options=[
                ("Periodic", 0),
                ("Free-slip", 1),
                ("Outflow", 2),
            ],
        ),
    )

    # Documented in the class docstring
    ncly1, nclyn, nclz1, nclzn = [
        traitlets.Int(default_value=2, min=0, max=2).tag(
            group="BasicParam",
            widget=widgets.Dropdown(
                description=name,
                options=[("Periodic", 0), ("Free-slip", 1), ("No-slip", 2)],
            ),
        ) for name in "ncly1 nclyn nclz1 nclzn".split()
    ]

    # ``ivisu``, ``ipost`` and ``ilesmod`` are declared further down in this
    # section as 0/1 ``Int`` traits with dropdown widgets. A ``Bool``
    # declaration of the same names at this point would be silently replaced
    # by that later binding, so none is made here.

    istret = traitlets.Int(
        default_value=0,
        min=0,
        max=3,
    ).tag(
        group="BasicParam",
        widget=widgets.Dropdown(
            description="istret",
            options=[
                ("No refinement", 0),
                ("Refinement at the center", 1),
                ("Both sides", 2),
                ("Just near the bottom", 3),
            ],
        ),
    )
    """int: Mesh refinement in **y**:

    * 0 - No refinement (default);
    * 1 - Refinement at the center;
    * 2 - Both sides;
    * 3 - Just near the bottom.

    Notes
    -----
        See :obj:`beta`.
    """

    beta = traitlets.Float(
        default_value=1.0,
        min=0,
    ).tag(
        group="BasicParam",
        widget=widgets.BoundedFloatText(
            description="beta",
            min=0.0,
            max=1e6,
        ),
    )
    """float: Refinement factor in **y**.

    Notes
    -----
        Only needed if :obj:`istret` :math:`\\ne` 0.
    """

    dt = traitlets.Float(
        default_value=1e-3,
        min=0.0,
    ).tag(
        group="BasicParam",
        widget=widgets.BoundedFloatText(
            description="dt",
            min=0.0,
            max=1e6,
        ),
    )
    """float: Time step :math:`(\\Delta t)`.
    """

    # Documented in the class docstring
    ifirst, ilast = (
        traitlets.Int(default_value=0, min=0).tag(
            group="BasicParam",
            widget=widgets.IntText(description=name),
        )
        for name in ("ifirst", "ilast")
    )

    re = traitlets.Float(default_value=1e3).tag(
        group="BasicParam",
        widget=widgets.FloatText(description="re"),
    )
    """float: Reynolds number :math:`(Re)`.
    """

    init_noise = traitlets.Float(default_value=0.0).tag(
        group="BasicParam",
        widget=widgets.FloatText(description="init_noise"),
    )
    """float: Random number amplitude at the initial condition.

    Notes
    -----
        The exact behavior may differ depending on the flow configuration.

        Only needed if :obj:`iin` :math:`\\ne` 0.
    """

    inflow_noise = traitlets.Float(default_value=0.0).tag(
        group="BasicParam",
        widget=widgets.FloatText(description="inflow_noise"),
    )
    """float: Random number amplitude at the inflow boundary (where :math:`x=0`).

    Notes
    -----
        Only needed if :obj:`nclx1` is equal to 2.
    """

    # On/Off switches, stored as the integers 0 and 1
    ilesmod, ivisu, ipost = (
        traitlets.Int(default_value=1, min=0, max=1).tag(
            group="BasicParam",
            widget=widgets.Dropdown(
                description=name,
                options=[("Off", 0), ("On", 1)],
            ),
        )
        for name in ("ilesmod", "ivisu", "ipost")
    )

    iibm = traitlets.Int(
        default_value=0,
        min=0,
        max=2,
    ).tag(
        group="BasicParam",
        widget=widgets.Dropdown(
            description="iibm",
            options=[
                ("Off", 0),
                ("Forced to zero", 1),
                ("Interpolated to zero", 2),
            ],
        ),
    )
    """int: Immersed Boundary Method (IBM) switch:

    * 0 - Off (default);
    * 1 - On with direct forcing method, i.e.,
      it sets velocity to zero inside the solid body;
    * 2 - On with alternating forcing method, i.e., it uses
      Lagrangian Interpolators to define the velocity inside the body
      and imposes no-slip condition at the solid/fluid interface.
    """

    numscalar = traitlets.Int(
        default_value=0,
        min=0,
        max=9,
    ).tag(
        group="BasicParam",
        widget=widgets.IntSlider(
            min=0,
            max=9,
            description="numscalar",
            continuous_update=False,
        ),
    )
    """int: Number of scalar fractions, each of which can have different properties.

    Notes
    -----
        More than 9 will break Xcompact3d, because it handles the I/O for
        scalar fields with just one digit.
    """

    gravx, gravy, gravz = (
        traitlets.Float(default_value=0.0).tag(
            group="BasicParam",
            widget=widgets.FloatText(description=name, disabled=True),
        )
        for name in ("gravx", "gravy", "gravz")
    )
    """float: Component of the unit vector pointing in the direction of gravity.
    """

    #
    # # NumOptions
    #

    ifirstder = traitlets.Int(default_value=4, min=1, max=4).tag(
        group="NumOptions",
        widget=widgets.Dropdown(
            description="ifirstder",
            disabled=True,
            # Each label needs its own value, matching the list documented
            # below; otherwise distinct schemes collapse onto the same value.
            options=[
                ("2nd central", 1),
                ("4th central", 2),
                ("4th compact", 3),
                ("6th compact", 4),
            ],
        ),
    )
    """int: Scheme for first order derivative:

    * 1 - 2nd central;
    * 2 - 4th central;
    * 3 - 4th compact;
    * 4 - 6th compact (default).
    """

    isecondder = traitlets.Int(
        default_value=4,
        min=1,
        max=5,
    ).tag(
        group="NumOptions",
        widget=widgets.Dropdown(
            description="isecondder",
            disabled=True,
            # NOTE(review): the trait accepts 1-5 but the dropdown only
            # offers 4 and 5 - confirm this restriction is intended.
            options=[
                ("6th compact", 4),
                ("hyperviscous 6th", 5),
            ],
        ),
    )
    """int: Scheme for second order derivative:

    * 1 - 2nd central;
    * 2 - 4th central;
    * 3 - 4th compact;
    * 4 - 6th compact (default);
    * 5 - Hyperviscous 6th.
    """

    # No widget and no group tag for this one; the default is 3.
    ipinter = traitlets.Int(3)

    itimescheme = traitlets.Int(
        default_value=3,
        min=1,
        max=7,
    ).tag(
        group="NumOptions",
        widget=widgets.Dropdown(
            description="itimescheme",
            options=[
                ("Euler", 1),
                ("AB2", 2),
                ("AB3", 3),
                ("RK3", 5),
                ("Semi-implicit", 7),
            ],
        ),
    )
    """int: Time integration scheme:

    * 1 - Euler;
    * 2 - AB2;
    * 3 - AB3 (default);
    * 5 - RK3;
    * 7 - Semi-implicit.
    """

    nu0nu = traitlets.Float(
        default_value=4,
        min=0.0,
    ).tag(
        group="NumOptions",
        widget=widgets.BoundedFloatText(
            description="nu0nu",
            min=0.0,
            max=1e6,
            disabled=True,
        ),
    )
    """float: Ratio between hyperviscosity/viscosity at nu.
    """

    cnu = traitlets.Float(
        default_value=0.44,
        min=0.0,
    ).tag(
        group="NumOptions",
        widget=widgets.BoundedFloatText(
            description="cnu",
            min=0.0,
            max=1e6,
            disabled=True,
        ),
    )
    """float: Ratio between hyperviscosity at :math:`k_m=2/3\\pi` and :math:`k_c= \\pi`.
    """

    #
    # # InOutParam
    #

    irestart = traitlets.Int(
        default_value=0,
        min=0,
        max=1,
    ).tag(
        group="InOutParam",
        widget=widgets.Dropdown(
            description="irestart",
            options=[("Off", 0), ("On", 1)],
        ),
    )
    """int: Reads the initial flow field if equal to 1.
    """

    nvisu = traitlets.Int(
        default_value=1,
        min=1,
    ).tag(
        group="InOutParam",
        widget=widgets.BoundedIntText(
            description="nvisu",
            min=1,
            max=1e9,
            disabled=True,
        ),
    )
    """int: Size for visual collection.
    """

    # Documented in the class docstring
    icheckpoint, ioutput, iprocessing = (
        traitlets.Int(default_value=1000, min=1).tag(
            group="InOutParam",
            widget=widgets.BoundedIntText(description=name, min=1, max=1e9),
        )
        for name in ("icheckpoint", "ioutput", "iprocessing")
    )

    # No widget and no group tag for this one.
    ifilenameformat = traitlets.Int(default_value=9, min=1)

    #
    # # ScalarParam
    #

    _iscalar = traitlets.Bool(False)

    # List-valued traits have no widgets yet; adding them needs some
    # planning about the code design.
    sc = traitlets.List(
        trait=traitlets.Float(),
    ).tag(group="ScalarParam")
    """:obj:`list` of :obj:`float`: Schmidt number(s).
    """

    ri = traitlets.List(
        trait=traitlets.Float(),
    ).tag(group="ScalarParam")
    """:obj:`list` of :obj:`float`: Richardson number(s).
    """

    uset = traitlets.List(
        trait=traitlets.Float(),
    ).tag(group="ScalarParam")
    """:obj:`list` of :obj:`float`: Settling velocity(ies).
    """

    cp = traitlets.List(
        trait=traitlets.Float(),
    ).tag(group="ScalarParam")
    """:obj:`list` of :obj:`float`: Initial concentration(s).
    """

    scalar_lbound = traitlets.List(
        trait=traitlets.Float(default_value=-1e6),
    ).tag(group="ScalarParam")
    """:obj:`list` of :obj:`float`: Lower scalar bound(s), for clipping methodology.
    """

    scalar_ubound = traitlets.List(
        trait=traitlets.Float(default_value=1e6),
    ).tag(group="ScalarParam")
    """:obj:`list` of :obj:`float`: Upper scalar bound(s), for clipping methodology.
    """

    iibmS = traitlets.Int(
        default_value=0,
        min=0,
        max=3,
    ).tag(
        group="ScalarParam",
        widget=widgets.Dropdown(
            description="iibmS",
            options=[
                ("Off", 0),
                ("Forced to zero", 1),
                ("Interpolated to zero", 2),
                ("Interpolated to no-flux", 3),
            ],
        ),
    )
    """int: Immersed Boundary Method (IBM) switch for the scalar field(s):

    * 0 - Off (default);
    * 1 - On with direct forcing method, i.e.,
      it sets scalar concentration to zero inside the solid body;
    * 2 - On with alternating forcing method, i.e., it uses
      Lagrangian Interpolators to define the scalar field inside the body
      and imposes zero value at the solid/fluid interface.
    * 3 - On with alternating forcing method, but now the Lagrangian
      Interpolators are set to impose no-flux for the scalar field at the
      solid/fluid interface.

      .. note:: It is only recommended if the normal vectors to the object's
            faces are aligned with one of the coordinate axes.
    """

    # Documented in the class docstring
    nclxS1, nclxSn, nclyS1, nclySn, nclzS1, nclzSn = (
        traitlets.Int(default_value=2, min=0, max=2).tag(
            group="ScalarParam",
            widget=widgets.Dropdown(
                description=name,
                options=[
                    ("Periodic", 0),
                    ("No-flux", 1),
                    ("Dirichlet", 2),
                ],
            ),
        )
        for name in (
            "nclxS1",
            "nclxSn",
            "nclyS1",
            "nclySn",
            "nclzS1",
            "nclzSn",
        )
    )

    #
    # # LESModel
    #

    jles = traitlets.Int(default_value=0, min=0, max=4).tag(
        group="LESModel",
        widget=widgets.Dropdown(
            # Labelled with the trait's own name, like every other widget in
            # this class; "ilesmod" is already the label of the separate
            # ``ilesmod`` On/Off dropdown.
            description="jles",
            options=[
                ("DNS", 0),
                ("Phys Smag", 1),
                ("Phys WALE", 2),
                ("Phys dyn. Smag", 3),
                ("iSVV", 4),
            ],
        ),
    )
    """int: Chooses the LES model:

    * 0 - No model (DNS) (default);
    * 1 - Phys Smag;
    * 2 - Phys WALE;
    * 3 - Phys dyn. Smag;
    * 4 - iSVV.

    """
    #
    # # ibmstuff
    #

    nobjmax = traitlets.Int(
        default_value=1,
        min=0,
    ).tag(
        group="ibmstuff",
        widget=widgets.IntText(description="nobjmax", disabled=True),
    )
    """int: Maximum number of objects in any direction. It is defined
        automatically at :obj:`gene_epsi_3D`.
    """

    nraf = traitlets.Int(
        default_value=10,
        min=1,
    ).tag(
        group="ibmstuff",
        widget=widgets.IntSlider(min=1, max=25, description="nraf"),
    )
    """int: Refinement constant.
    """

    # Auxiliar
    filename = traitlets.Unicode(
        default_value="input.i3d",
    ).tag(
        widget=widgets.Text(description="filename"),
    )
    """str: Filename for the ``.i3d`` file.
    """

    # One empty dictionary per group name, in this order.
    _i3d = traitlets.Dict(
        default_value={
            group_name: {}
            for group_name in (
                "BasicParam",
                "NumOptions",
                "InOutParam",
                "Statistics",
                "ScalarParam",
                "LESModel",
                "WallModel",
                "ibmstuff",
                "ForceCVs",
                "CASE",
            )
        }
    )

    _mx, _my, _mz = (traitlets.Int(default_value=1, min=1) for _ in range(3))

    dx, dy, dz = (
        traitlets.Float(default_value=0.0625, min=0.0).tag(
            widget=widgets.BoundedFloatText(
                description=name,
                min=0.0,
                max=1e6,
            ),
        )
        for name in ("dx", "dy", "dz")
    )
    """float: Mesh resolution.
    """

    _nclx, _ncly, _nclz = (traitlets.Bool() for _ in range(3))
    """bool: Auxiliary variable for boundary condition,
        it is :obj:`True` if Periodic and :obj:`False` otherwise.
    """

    _possible_mesh_x, _possible_mesh_y, _possible_mesh_z = (
        traitlets.List(trait=traitlets.Int(), default_value=possible_mesh)
        for _ in range(3)
    )
    """:obj:`list` of :obj:`int`: Auxiliary variable for mesh points widgets,
        it stores the available options according to the boundary conditions.
    """

    # NOTE(review): the trait enforces min=1 with default 4, while the widget
    # starts at value=0 with min=0 - confirm which bound is intended.
    ncores = traitlets.Int(
        default_value=4,
        min=1,
    ).tag(
        widget=widgets.BoundedIntText(
            value=0,
            min=0,
            description="ncores",
            max=1e9,
        ),
    )
    """int: Number of computational cores where Xcompact3d will run.
    """

    _possible_p_row, _possible_p_col = (
        traitlets.List(
            trait=traitlets.Int(),
            default_value=list(divisorGenerator(4)),
        )
        for _ in range(2)
    )
    """:obj:`list` of :obj:`int`: Auxiliary variable for parallel domain decomposition,
        it stores the available options according to :obj:`ncores`.
    """

    # cfl = traitlets.Float(0.0)
    _size_in_disc = traitlets.Unicode().tag(
        widget=widgets.Text(
            value="",
            description="Size",
            disabled=True,
        ),
    )
    """str: Auxiliary variable indicating the demanded space in disc
    """
    def __init__(self, **kwargs):
        """Initializes the Parameters Class.

        Boundary conditions found in ``kwargs`` are assigned first; after
        that every keyword (boundary conditions included) is assigned in the
        order given.

        Parameters
        ----------
        **kwargs
            Keyword arguments for valid attributes.

        Raises
        ------
        KeyError
            Raised when a keyword argument is not a valid attribute.

        Examples
        --------

        There are a few ways to initialize the class.

        First, calling it with no
        arguments initializes all variables with default value:

        >>> prm = xcompact3d_toolbox.Parameters()

        It is possible to set any values afterwards (including new attributes):

        >>> prm.re = 1e6

        Second, we can specify some values, and let the missing ones be
        initialized with default value:

        >>> prm = x3d.Parameters(
        ...     filename = 'example.i3d',
        ...     itype = 10,
        ...     nx = 257,
        ...     ny = 129,
        ...     nz = 32,
        ...     xlx = 15.0,
        ...     yly = 10.0,
        ...     zlz = 3.0,
        ...     nclx1 = 2,
        ...     nclxn = 2,
        ...     ncly1 = 1,
        ...     nclyn = 1,
        ...     nclz1 = 0,
        ...     nclzn = 0,
        ...     re = 300.0,
        ...     init_noise = 0.0125,
        ...     dt = 0.0025,
        ...     ilast = 45000,
        ...     ioutput = 200,
        ...     iprocessing = 50
        ... )

        And finally, it is possible to read the parameters from the disc:

        >>> prm = xcompact3d_toolbox.Parameters(filename = 'example.i3d')
        >>> prm.read()

        """

        super().__init__()

        # Boundary conditions are high priority in order to avoid bugs,
        # so they are applied before anything else.
        boundary_conditions = (
            "nclx1",
            "nclxn",
            "ncly1",
            "nclyn",
            "nclz1",
            "nclzn",
        )
        for bc_name in boundary_conditions:
            if bc_name in kwargs:
                setattr(self, bc_name, kwargs[bc_name])

        # Then every keyword in turn; an unknown name aborts the construction.
        for key, value in kwargs.items():
            if key in self.trait_names():
                setattr(self, key, value)
                continue
            raise KeyError(f"There is no parameter named {key}!")

        # self.link_widgets()

    def __call__(self, *args):
        """Return widgets on demand.

        Parameters
        ----------
        *args : str
            Name(s) of the desired widget(s). When no name is given, the
            complete user interface is returned.

        Returns
        -------
        :obj:`ipywidgets.VBox`
            Widgets for a user-friendly interface.

        Raises
        ------
        KeyError
            Raised if an argument is not the name of a trait of this object.
            Traits that exist but carry no ``widget`` tag are skipped silently.

        Examples
        --------

        >>> prm = xcompact3d_toolbox.Parameters()
        >>> prm()

        >>> prm('nx', 'xlx', 'dx', 'nclx1', 'nclxn')

        """

        def widget_of(name):
            # Widget stored in the ``widget`` metadata tag of a trait (or None).
            return self.trait_metadata(name, "widget")

        def row(names):
            # One horizontal box holding the widget of every named trait.
            return widgets.HBox([widget_of(name) for name in names])

        def disabled_button(label, icon):
            # The toolbar buttons are placeholders: they are always disabled.
            return widgets.Button(description=label, disabled=True, icon=icon)

        if not args:
            dim = "x y z".split()

            return widgets.VBox([
                widgets.HTML(value="<h1>Xcompact3d Parameters</h1>"),
                widgets.HBox([
                    widget_of("filename"),
                    disabled_button("Read", "file-upload"),
                    disabled_button("Write", "file-download"),
                    disabled_button("Run", "rocket"),
                    disabled_button("Sync", "sync"),
                ]),
                widgets.HTML(value="<h2>BasicParam</h2>"),
                row("itype re".split()),
                row("iin init_noise inflow_noise".split()),
                widgets.HTML(value="<h3>Domain Decomposition</h3>"),
                row("ncores p_row p_col".split()),
                widgets.HTML(value="<h3>Temporal discretization</h3>"),
                row("ifirst ilast dt".split()),
                widgets.HTML(value="<h3>InOutParam</h3>"),
                row("irestart nvisu _size_in_disc".split()),
                row("icheckpoint ioutput iprocessing".split()),
                widgets.HTML(value="<h3>Spatial discretization</h3>"),
                row([f"n{d}" for d in dim]),
                row([f"{d}l{d}" for d in dim]),
                row([f"d{d}" for d in dim]),
                row([f"ncl{d}1" for d in dim]),
                row([f"ncl{d}n" for d in dim]),
                row("istret beta".split()),
                widgets.HTML(value="<h2>NumOptions</h2>"),
                row("ifirstder isecondder itimescheme".split()),
                row("ilesmod nu0nu cnu".split()),
                widgets.HTML(value="<h2>ScalarParam</h2>"),
                row(["numscalar"]),
                row([f"ncl{d}S1" for d in dim]),
                row([f"ncl{d}Sn" for d in dim]),
                row([f"grav{d}" for d in dim]),
                row("iibmS".split()),
                widgets.HTML(
                    value=
                    "<strong>cp, us, sc, ri, scalar_lbound & scalar_ubound</strong> are lists with length numscalar, set them properly on the code."
                ),
                widgets.HTML(value="<h2>IBMStuff</h2>"),
                row("iibm nraf nobjmax".split()),
            ])

        known_names = self.trait_names()
        widgets_list = []
        for name in args:
            if name not in known_names:
                raise KeyError(f"There is no parameter named {name}!")
            widget = widget_of(name)
            # Identity test: a trait without a ``widget`` tag yields None.
            if widget is not None:
                widgets_list.append(widget)

        return widgets.VBox(widgets_list)

    @traitlets.validate("nx")
    def _validade_mesh_nx(self, proposal):
        """Check a proposed ``nx`` with ``_validate_mesh`` and accept it as is.

        The proposed value is handed to ``_validate_mesh`` together with the
        ``_nclx`` flag and both x boundary conditions; the value itself is
        returned unchanged.
        """
        candidate = proposal["value"]
        _validate_mesh(candidate, self._nclx, self.nclx1, self.nclxn, "x")
        return candidate

    @traitlets.validate("ny")
    def _validade_mesh_ny(self, proposal):
        """Check a proposed ``ny`` with ``_validate_mesh`` and accept it as is.

        The proposed value is handed to ``_validate_mesh`` together with the
        ``_ncly`` flag and both y boundary conditions; the value itself is
        returned unchanged.
        """
        candidate = proposal["value"]
        _validate_mesh(candidate, self._ncly, self.ncly1, self.nclyn, "y")
        return candidate

    @traitlets.validate("nz")
    def _validade_mesh_nz(self, proposal):
        """Check a proposed ``nz`` with ``_validate_mesh`` and accept it as is.

        The proposed value is handed to ``_validate_mesh`` together with the
        ``_nclz`` flag and both z boundary conditions; the value itself is
        returned unchanged.
        """
        candidate = proposal["value"]
        _validate_mesh(candidate, self._nclz, self.nclz1, self.nclzn, "z")
        return candidate

    @traitlets.observe("dx", "nx", "xlx", "dy", "ny", "yly", "dz", "nz", "zlz")
    def _observe_resolution(self, change):
        """Keep point count, domain length and mesh spacing consistent.

        For the direction named by the last character of the changed trait:

        * a new ``n<dim>`` updates ``_m<dim>`` (the new value when the
          ``_ncl<dim>`` flag is set, one less otherwise) and recomputes
          ``d<dim>`` as length divided by ``_m<dim>``;
        * a new ``d<dim>`` recomputes the length ``<dim>l<dim>``;
        * a new ``<dim>l<dim>`` recomputes ``d<dim>``.

        In the last two cases the dependent trait is only assigned when the
        result differs from its current value.
        """
        trait, value = change["name"], change["new"]
        axis = trait[-1]  # x, y or z
        points = f"n{axis}"
        spacing = f"d{axis}"
        length = f"{axis}l{axis}"
        intervals = f"_m{axis}"

        if trait == points:
            setattr(self, intervals,
                    value if getattr(self, f"_ncl{axis}") else value - 1)
            setattr(self, spacing,
                    getattr(self, length) / getattr(self, intervals))
        elif trait == spacing:
            candidate = value * getattr(self, intervals)
            if candidate != getattr(self, length):
                setattr(self, length, candidate)
        elif trait == length:
            candidate = value / getattr(self, intervals)
            if candidate != getattr(self, spacing):
                setattr(self, spacing, candidate)

    @traitlets.observe(
        "nclx1",
        "nclxn",
        "nclxS1",
        "nclxSn",
        "ncly1",
        "nclyn",
        "nclyS1",
        "nclySn",
        "nclz1",
        "nclzn",
        "nclzS1",
        "nclzSn",
    )
    def _observe_bc(self, change):
        """Propagate a boundary-condition change within one direction.

        When any of the four boundary traits of a direction becomes ``0``,
        all four are set to ``0`` and the ``_ncl<dim>`` flag is raised.
        When one of them leaves ``0``, all four take the new value and the
        flag is cleared. (``0`` presumably denotes a periodic boundary --
        confirm against the solver documentation.)
        """
        axis = change["name"][3]  # x, y or z
        previous, current = change["old"], change["new"]
        siblings = f"ncl{axis}1 ncl{axis}n ncl{axis}S1 ncl{axis}Sn".split()
        flag = f"_ncl{axis}"

        if current == 0:
            for trait in siblings:
                setattr(self, trait, 0)
            setattr(self, flag, True)
        elif previous == 0:
            for trait in siblings:
                setattr(self, trait, current)
            setattr(self, flag, False)

    @traitlets.observe("_nclx", "_ncly", "_nclz")
    def _observe_periodicity(self, change):
        """Adjust the point count and mesh options when a ``_ncl`` flag flips.

        A flag becoming true decrements ``n<dim>`` by one and installs
        ``possible_mesh_p`` as the allowed values; a flag becoming false
        increments ``n<dim>`` by one and installs ``possible_mesh``. The
        options are replaced before the new point count is assigned.
        """
        axis = change["name"][-1]  # x, y or z
        points = f"n{axis}"

        if change["new"]:
            updated = getattr(self, points) - 1
            options = possible_mesh_p
        else:
            updated = getattr(self, points) + 1
            options = possible_mesh

        setattr(self, f"_possible_mesh_{axis}", options)
        setattr(self, points, updated)

    @traitlets.observe("p_row", "p_col", "ncores")
    def _observe_2Decomp(self, change):
        """Keep ``ncores``, ``p_row`` and ``p_col`` consistent.

        * A new ``ncores`` recomputes the candidate values for ``p_row`` and
          ``p_col`` from ``divisorGenerator`` and resets both to ``0``.
        * A new ``p_row`` sets ``p_col`` to ``ncores // p_row``.
        * A new ``p_col`` sets ``p_row`` to ``ncores // p_col``.

        If the complementary value cannot be computed or assigned (for
        instance a division by zero), it falls back to ``0``.
        """
        if change["name"] == "ncores":
            possible = list(divisorGenerator(change["new"]))
            self._possible_p_row = possible
            self._possible_p_col = possible
            self.p_row, self.p_col = 0, 0
        elif change["name"] == "p_row":
            try:
                self.p_col = self.ncores // self.p_row
            except Exception:
                # Best-effort fallback; ``Exception`` (not a bare except) so
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                self.p_col = 0
        elif change["name"] == "p_col":
            try:
                self.p_row = self.ncores // self.p_col
            except Exception:
                self.p_row = 0

    @traitlets.observe("ilesmod")
    def _observe_ilesmod(self, change):
        """Enable or disable the widgets that depend on ``ilesmod``.

        With ``ilesmod == 0`` the traits ``nu0nu``, ``cnu`` and
        ``isecondder`` are reset to ``4.0``, ``0.44`` and ``4`` and their
        widgets are disabled; for any other value the widgets are enabled.
        """
        locked = change["new"] == 0
        if locked:
            self.nu0nu, self.cnu, self.isecondder = 4.0, 0.44, 4

        for trait in ("nu0nu", "cnu", "isecondder"):
            self.trait_metadata(trait, "widget").disabled = bool(locked)

    @traitlets.observe("numscalar")
    def _observe_numscalar(self, change):
        """Set ``_iscalar`` to ``True`` exactly when ``numscalar`` is zero."""
        no_scalars = bool(change["new"] == 0)
        self._iscalar = no_scalars

    @traitlets.observe(
        "numscalar",
        "nx",
        "ny",
        "nz",
        "nvisu",
        "icheckpoint",
        "ioutput",
        "iprocessing",
        "ilast",
    )
    def _observe_size_in_disc(self, change):
        """Refresh ``_size_in_disc`` with an estimate of the output volume.

        The estimate adds up restart files, 3D snapshots and (only for
        ``itype == 10``) 2D planes, using 4 bytes per value when ``mytype``
        is ``np.float32`` and 8 bytes otherwise. The result is stored as a
        human-readable string.
        """
        def convert_bytes(num):
            """Format a byte count with decimal units (bytes, KB, ... PB)."""
            step_unit = 1000.0  # decimal steps, not 1024

            for unit in ["bytes", "KB", "MB", "GB", "TB"]:
                if num < step_unit:
                    return "%3.1f %s" % (num, unit)
                num /= step_unit
            # 1000 TB or more: previously the loop ended without a return
            # value, so the size silently became None.
            return "%3.1f %s" % (num, "PB")

        prec = 4 if mytype == np.float32 else 8

        # Restart Size from tools.f90
        count = 3 + self.numscalar  # ux, uy, uz, phi
        # Previous time-step if necessary
        if self.itimescheme in [3, 7]:
            count *= 3
        elif self.itimescheme == 2:
            count *= 2
        count += 1  # pp
        count *= (self.nx * self.ny * self.nz * prec *
                  (self.ilast // self.icheckpoint - 1))

        # 3D from visu.f90: ux, uy, uz, pp and phi
        count += ((4 + self.numscalar) * self.nx * self.ny * self.nz * prec *
                  self.ilast // self.ioutput)

        # 2D planes from BC.Sandbox.f90
        if self.itype == 10:
            # xy planes avg and central plane for ux, uy, uz and phi
            count += (2 * (3 + self.numscalar) * self.nx * self.ny * prec *
                      self.ilast // self.iprocessing)
            # xz planes avg, top and bot for ux, uy, uz and phi
            count += (3 * (3 + self.numscalar) * self.nx * self.nz * prec *
                      self.ilast // self.iprocessing)

        self._size_in_disc = convert_bytes(count)

    def _class_to_dict(self):
        """Copy every trait tagged with a ``group`` into ``self._i3d``.

        ``self._i3d`` is a dictionary of dictionaries keyed first by group
        name and then by trait name. Missing groups are created on demand;
        traits without a ``group`` tag are ignored.
        """
        for name in self.trait_names():
            group = self.trait_metadata(name, "group")
            if group is None:
                continue
            self._i3d.setdefault(group, {})[name] = getattr(self, name)

    def _dict_to_class(self):
        """Assign the values stored in ``self._i3d`` to the matching traits.

        Boundary conditions found in the ``BasicParam`` group are applied
        first; afterwards every trait is looked up under its ``group`` tag.
        Traits that are absent from the dictionary, or whose value cannot be
        assigned, are skipped.

        Raises
        ------
        KeyError
            If ``self._i3d`` has no ``BasicParam`` group.
        """

        # Boundary conditions are high priority in order to avoid bugs
        basic = self._i3d["BasicParam"]
        for bc in "nclx1 nclxn ncly1 nclyn nclz1 nclzn".split():
            if bc in basic:
                setattr(self, bc, basic[bc])

        for name in self.trait_names():
            try:
                group = self.trait_metadata(name, "group")
                setattr(self, name, self._i3d[group][name])
            except Exception:
                # Best effort: the entry is missing or was rejected by the
                # trait. ``Exception`` keeps KeyboardInterrupt/SystemExit
                # propagating, unlike the former bare except.
                pass

    def read(self):
        """Read all valid attributes from an ``.i3d`` file.

        The file named by ``self.filename`` is parsed with ``i3d_to_dict``
        into ``self._i3d`` and then applied to the object.

        An attribute is considered valid if it has a ``tag`` named ``group``,
        which assigns it to the respective namespace in the ``.i3d`` file.

        Examples
        --------

        >>> prm = xcompact3d_toolbox.Parameters(filename = 'example.i3d')
        >>> prm.read()

        """
        self._i3d = i3d_to_dict(self.filename)
        self._dict_to_class()

    def write(self):
        """Write all valid attributes to an ``.i3d`` file.

        The current trait values are collected into ``self._i3d`` and handed
        to ``dict_to_i3d`` together with ``self.filename``.

        An attribute is considered valid if it has a ``tag`` named ``group``,
        which assigns it to the respective namespace in the ``.i3d`` file.

        Examples
        --------

        >>> prm = xcompact3d_toolbox.Parameters(filename = 'example.i3d')
        >>> prm.write()

        """
        self._class_to_dict()
        dict_to_i3d(self._i3d, self.filename)

    def link_widgets(self, silence=True):
        """Create two-way links between attributes and their widgets.

        Four kinds of links are set up:

        * every trait with its widget's ``value``;
        * ``_possible_mesh_<dim>`` with the ``options`` of the ``n<dim>``
          widgets, and ``_possible_p_row`` / ``_possible_p_col`` with the
          ``options`` of the ``p_row`` / ``p_col`` widgets;
        * ``_iscalar`` with the ``disabled`` state of every widget in the
          ``ScalarParam`` group (except ``numscalar``);
        * tooltips taken from ``description`` where available.

        Parameters
        ----------
        silence : bool
            If :obj:`True` (default), traits that cannot be linked are
            skipped quietly; if :obj:`False`, a message is printed for each.

        Examples
        --------

        >>> prm = xcompact3d_toolbox.Parameters(filename = 'example.i3d')
        >>> prm.link_widgets()

        """
        # Create two-way link between variable and widget
        for name in self.trait_names():
            try:
                traitlets.link((self, name),
                               (self.trait_metadata(name, "widget"), "value"))
            except Exception:
                # Not every trait has a widget with a ``value``; best effort.
                if not silence:
                    print(f"Widget not linked for {name}")

        for dim in ["x", "y", "z"]:
            traitlets.link(
                (self, f"_possible_mesh_{dim}"),
                (self.trait_metadata(f"n{dim}", "widget"), "options"),
            )
        for name in ["p_row", "p_col"]:
            traitlets.link(
                (self, f"_possible_{name}"),
                (self.trait_metadata(name, "widget"), "options"),
            )

        for name in self.trait_names():
            if name == "numscalar":
                continue
            if self.trait_metadata(name, "group") != "ScalarParam":
                continue
            try:
                traitlets.link(
                    (self, "_iscalar"),
                    (self.trait_metadata(name, "widget"), "disabled"),
                )
            except Exception:
                if not silence:
                    print(f"Widget not linked to numscalar for {name}")

        # Try adding a description
        for name in self.trait_names():
            if name in description:
                try:
                    self.trait_metadata(
                        name, "widget").description_tooltip = description[name]
                except Exception:
                    # Trait has no widget, or it rejects the tooltip.
                    pass

    def write_xdmf(self):
        """Writes four xdmf files:

        * ``./data/3d_snapshots.xdmf`` for 3D snapshots in ``./data/3d_snapshots/*``;
        * ``./data/xy_planes.xdmf`` for planes in ``./data/xy_planes/*``;
        * ``./data/xz_planes.xdmf`` for planes in ``./data/xz_planes/*``;
        * ``./data/yz_planes.xdmf`` for planes in ``./data/yz_planes/*``.

        Shape and time are inferred from folder structure and filenames.
        File list is obtained automatically with :obj:`glob`.

        This method takes no arguments; it passes the object itself to the
        module-level ``write_xdmf`` function.

        .. note:: This is only compatible with the new filename structure,
            the conversion is exemplified in `convert_filenames_x3d_toolbox`_.

        .. _`convert_filenames_x3d_toolbox`: https://gist.github.com/fschuch/5a05b8d6e9787d76655ecf25760e7289

        Examples
        --------

        >>> prm = x3d.Parameters()
        >>> prm.write_xdmf()

        """
        write_xdmf(self)

    @property
    def get_mesh(self):
        """Get mesh point locations for the three coordinates.

        This is a property: access it without parentheses. The result comes
        from the module-level ``get_mesh`` function called with this object.

        Returns
        -------
        :obj:`dict` of :obj:`numpy.ndarray`
            It contains the mesh point locations at three dictionary keys,
            for **x**, **y** and **z**.

        Examples
        --------

        >>> prm = xcompact3d_toolbox.Parameters()
        >>> prm.get_mesh

        """
        return get_mesh(self)
Ejemplo n.º 24
0
class VizHistogramState(VizBaseState):
    """State of a one-dimensional histogram visualisation.

    Changes to the traits are wired to signals in ``__init__``: a new
    ``x_slice`` emits ``signal_slice``; a new expression, type or aux
    recomputes the limits; new limits or shape refresh the grid.
    """

    x_expression = traitlets.Unicode()
    x_slice = traitlets.CInt(None, allow_none=True)
    type = traitlets.CaselessStrEnum(['count', 'min', 'max', 'mean'],
                                     default_value='count')
    aux = traitlets.Unicode(None, allow_none=True)
    groupby = traitlets.Unicode(None, allow_none=True)
    groupby_normalize = traitlets.Bool(False, allow_none=True)
    x_min = traitlets.CFloat(None, allow_none=True)
    x_max = traitlets.CFloat(None, allow_none=True)
    grid = traitlets.Any().tag(**serialize_numpy)
    grid_sliced = traitlets.Any().tag(**serialize_numpy)
    x_centers = traitlets.Any().tag(**serialize_numpy)
    x_shape = traitlets.CInt(None, allow_none=True)

    def __init__(self, ds, **kwargs):
        """Initialise the base state and register the trait observers."""
        super(VizHistogramState, self).__init__(ds, **kwargs)

        self.observe(lambda change: self.signal_slice.emit(self), ['x_slice'])
        self.observe(lambda change: self.calculate_limits(),
                     ['x_expression', 'type', 'aux'])
        self.observe(lambda change: self._update_grid(),
                     ['x_min', 'x_max', 'shape'])

        # Limits are only computed when neither bound was supplied.
        if self.x_min is not None or self.x_max is not None:
            self._calculate_centers()
        else:
            self.calculate_limits()

    def bin_parameters(self):
        """Yield one ``(expression, shape, (min, max), slice)`` tuple."""
        shape = self.x_shape or self.shape
        limits = (self.x_min, self.x_max)
        yield self.x_expression, shape, limits, self.x_slice

    def state_get(self):
        """Return a dict of all traits, passed through their serializer.

        A trait without a ``serialize`` tag uses ``ident``.
        """
        return {
            name: self.trait_metadata(name, 'serialize',
                                      ident)(getattr(self, name))
            for name in self.trait_names()
        }

    def state_set(self, state):
        """Restore the traits present in *state*, deserializing each value.

        A trait without a ``deserialize`` tag uses ``ident``.
        """
        for name in self.trait_names():
            if name not in state:
                continue
            restore = self.trait_metadata(name, 'deserialize', ident)
            setattr(self, name, restore(state[name]))

    def calculate_limits(self):
        """Recompute the x limits and request a regrid."""
        self._calculate_limits('x', 'x_expression')
        # TODO this is also called in the ctor, unnec work
        self.signal_regrid.emit(None)

    def limits_changed(self, change):
        """Request a regrid; *change* is not used."""
        # TODO this is also called in the ctor, unnec work
        self.signal_regrid.emit(None)

    @vaex.jupyter.debounced()
    def _update_grid(self):
        """Recompute the bin centers and request a regrid."""
        self._calculate_centers()
        self.signal_regrid.emit(None)
Ejemplo n.º 25
0
class XGBoostModel(state.HasState):
    '''The XGBoost algorithm.

    XGBoost is an optimized distributed gradient boosting library designed to be
    highly efficient, flexible and portable. It implements machine learning
    algorithms under the Gradient Boosting framework. XGBoost provides a parallel
    tree boosting (also known as GBDT, GBM) that solve many data science
    problems in a fast and accurate way.
    (https://github.com/dmlc/xgboost)

    Example:

    >>> import vaex
    >>> import vaex.ml.xgboost
    >>> df = vaex.ml.datasets.load_iris()
    >>> features = ['sepal_width', 'petal_length', 'sepal_length', 'petal_width']
    >>> df_train, df_test = vaex.ml.train_test_split(df)
    >>> params = {
    ...     'max_depth': 5,
    ...     'learning_rate': 0.1,
    ...     'objective': 'multi:softmax',
    ...     'num_class': 3,
    ...     'subsample': 0.80,
    ...     'colsample_bytree': 0.80,
    ...     'silent': 1}
    >>> booster = vaex.ml.xgboost.XGBoostModel(features=features, num_boost_round=100, params=params)
    >>> booster.fit(df_train, 'class_')
    >>> df_train = booster.transform(df_train)
    >>> df_train.head(3)
    #    sepal_length    sepal_width    petal_length    petal_width    class_    xgboost_prediction
    0             5.4            3               4.5            1.5         1                     1
    1             4.8            3.4             1.6            0.2         0                     0
    2             6.9            3.1             4.9            1.5         1                     1
    >>> df_test = booster.transform(df_test)
    >>> df_test.head(3)
    #    sepal_length    sepal_width    petal_length    petal_width    class_    xgboost_prediction
    0             5.9            3               4.2            1.5         1                     1
    1             6.1            3               4.6            1.4         1                     1
    2             6.6            2.9             4.6            1.3         1                     1
    '''

    features = traitlets.List(
        traitlets.Unicode(),
        help='List of features to use when fitting the XGBoostModel.')
    num_boost_round = traitlets.CInt(help='Number of boosting iterations.')
    params = traitlets.Dict(
        help='A dictionary of parameters to be passed on to the XGBoost model.'
    )
    prediction_name = traitlets.Unicode(
        default_value='xgboost_prediction',
        help='The name of the virtual column housing the predictions.')

    def __call__(self, *args):
        '''Predict from one array per feature.

        The arrays are cast to float64 and stacked column-wise into a 2D
        matrix before being handed to the booster.

        :param args: One array per feature.
        :return: The output of ``self.booster.predict``.
        '''
        data2d = np.vstack([arg.astype(np.float64) for arg in args]).T.copy()
        dmatrix = xgboost.DMatrix(data2d)
        return self.booster.predict(dmatrix)

    def transform(self, df):
        '''Transform a DataFrame such that it contains the predictions of the XGBoostModel in form of a virtual column.

        :param df: A vaex DataFrame. It should have the same columns as the DataFrame used to train the model.

        :return copy: A shallow copy of the DataFrame that includes the XGBoostModel prediction as a virtual column.
        :rtype: DataFrame
        '''
        copy = df.copy()
        lazy_function = copy.add_function('xgboost_prediction_function', self)
        expression = lazy_function(*self.features)
        copy.add_virtual_column(self.prediction_name, expression, unique=False)
        return copy

    def fit(self,
            df,
            target,
            evals=(),
            early_stopping_rounds=None,
            evals_result=None,
            verbose_eval=False,
            **kwargs):
        '''Fit the XGBoost model given a DataFrame.
        This method accepts all key word arguments for the xgboost.train method.

        :param df: A vaex DataFrame containing the training features.
        :param target: The column name of the target variable.
        :param evals: A list of pairs (DataFrame, string).
            List of items to be evaluated during training, this allows user to watch performance on the validation set.
        :param int early_stopping_rounds: Activates early stopping.
            Validation error needs to decrease at least every *early_stopping_rounds* round(s) to continue training.
            Requires at least one item in *evals*. If there's more than one, will use the last. Returns the model
            from the last iteration (not the best one).
        :param dict evals_result: A dictionary storing the evaluation results of all the items in *evals*.
        :param bool verbose_eval: Requires at least one item in *evals*.
            If *verbose_eval* is True then the evaluation metric on the validation set is printed at each boosting stage.
        '''
        data = df[self.features].values
        target_data = df.evaluate(target)
        dtrain = xgboost.DMatrix(data, target_data)
        if evals is not None:
            # Replace each DataFrame in the (DataFrame, name) pairs by a DMatrix.
            evals = [list(elem) for elem in evals]
            for item in evals:
                data = item[0][self.features].values
                target_data = item[0].evaluate(target)
                item[0] = xgboost.DMatrix(data, target_data)
        else:
            evals = ()

        # This does the actual training / fitting of the xgboost model
        self.booster = xgboost.train(
            params=self.params,
            dtrain=dtrain,
            num_boost_round=self.num_boost_round,
            evals=evals,
            early_stopping_rounds=early_stopping_rounds,
            evals_result=evals_result,
            verbose_eval=verbose_eval,
            **kwargs)

    def predict(self, df, **kwargs):
        '''Provided a vaex DataFrame, get an in-memory numpy array with the predictions from the XGBoost model.
        This method accepts the key word arguments of the predict method from XGBoost.

        :returns: A in-memory numpy array containing the XGBoostModel predictions.
        :rtype: numpy.array
        '''
        data = df[self.features].values
        dmatrix = xgboost.DMatrix(data)
        return self.booster.predict(dmatrix, **kwargs)

    def state_get(self):
        '''Return the serialized model.

        :return: A dict with ``tree_state`` (the saved booster, base64 encoded
            as ASCII text) and ``substate`` (the state of the parent class).
        '''
        import os

        # A private temporary directory replaces tempfile.mktemp(), which is
        # deprecated and race-prone, and guarantees the file is removed.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, 'model')
            self.booster.save_model(filename)
            with open(filename, 'rb') as f:
                data = f.read()
        return dict(tree_state=base64.encodebytes(data).decode('ascii'),
                    substate=super(XGBoostModel, self).state_get())

    def state_set(self, state):
        '''Restore the model from a dict produced by :meth:`state_get`.

        :param dict state: Must contain the keys ``substate`` and ``tree_state``.
        '''
        import os

        super(XGBoostModel, self).state_set(state['substate'])
        data = base64.decodebytes(state['tree_state'].encode('ascii'))
        # The booster is loaded while the temporary directory still exists;
        # the directory and file are removed on leaving the block.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, 'model')
            with open(filename, 'wb') as f:
                f.write(data)
            self.booster = xgboost.Booster(model_file=filename)
Ejemplo n.º 26
0
class PlotBase(widgets.Widget):
    """Widget that holds the state of a plot and drives a plotting backend."""

    # Traits tagged with sync=True.
    x = traitlets.Unicode(allow_none=False).tag(sync=True)
    y = traitlets.Unicode(allow_none=True).tag(sync=True)
    z = traitlets.Unicode(allow_none=True).tag(sync=True)
    w = traitlets.Unicode(allow_none=True).tag(sync=True)
    vx = traitlets.Unicode(allow_none=True).tag(sync=True)
    vy = traitlets.Unicode(allow_none=True).tag(sync=True)
    vz = traitlets.Unicode(allow_none=True).tag(sync=True)
    smooth_pre = traitlets.CFloat(None, allow_none=True).tag(sync=True)
    smooth_post = traitlets.CFloat(None, allow_none=True).tag(sync=True)
    what = traitlets.Unicode(allow_none=False).tag(sync=True)
    vcount_limits = traitlets.List([None, None],
                                   allow_none=True).tag(sync=True)
    # Traits without the sync tag.
    f = traitlets.Unicode(allow_none=True)
    grid_limits = traitlets.List(allow_none=True)
    grid_limits_min = traitlets.CFloat(None, allow_none=True)
    grid_limits_max = traitlets.CFloat(None, allow_none=True)

    def __init__(self,
                 backend,
                 dataset,
                 x,
                 y=None,
                 z=None,
                 w=None,
                 grid=None,
                 limits=None,
                 shape=128,
                 what="count(*)",
                 f=None,
                 vshape=16,
                 selection=None,
                 grid_limits=None,
                 normalize=None,
                 colormap="afmhot",
                 figure_key=None,
                 fig=None,
                 what_kwargs=None,
                 grid_before=None,
                 vcount_limits=None,
                 show_drawer=False,
                 controls_selection=True,
                 model=None,
                 x_col=None,
                 y_col=None,
                 x_label='Access Number',
                 y_label='Address',
                 update_stats=None,
                 **kwargs):
        """Build the plot widget and wire it to the backend and dataset.

        ``x``, ``y``, ``z``, ``w``, ``what``, ``vcount_limits``,
        ``grid_limits``, ``f`` and any extra keyword arguments are forwarded
        to the widget base class. Most other arguments are stored on the
        instance under the same name.

        :param backend: Object providing ``create_widget``, ``widget``,
            ``limits`` and ``observe``.
        :param dataset: Dataset being plotted.
        :param grid: Precomputed grid; when ``None`` the grid is computed
            with ``update_grid``.
        :param limits: Passed to ``get_limits`` to obtain ``self.limits``.
        :param shape: Accepted for compatibility; see the note in the body.
        :param what_kwargs: Optional dict; ``None`` (default) means a fresh
            empty dict per instance.
        :param show_drawer: Passed to ``PlotTemplate`` as ``model``.
        :param controls_selection: Accepted but not used in this method.
        :param model: Required. Must provide ``cache_size()``,
            ``plot_title()`` and ``legend``.
        :param x_col: Required.
        :param y_col: Required.
        :param update_stats: Required.
        :raises ValueError: If ``model``, ``x_col``, ``y_col`` or
            ``update_stats`` is ``None``.
        """

        super(PlotBase, self).__init__(x=x,
                                       y=y,
                                       z=z,
                                       w=w,
                                       what=what,
                                       vcount_limits=vcount_limits,
                                       grid_limits=grid_limits,
                                       f=f,
                                       **kwargs)
        self.backend = backend
        self.vgrids = [None, None, None]
        self.vcount = None
        self.dataset = dataset
        self.limits = self.get_limits(limits)
        # NOTE(review): the ``shape`` argument is immediately overridden by
        # the hard-coded 512 below. Kept as is to preserve behavior; confirm
        # whether the override is intentional.
        self.shape = shape
        self.shape = 512
        self.selection = selection
        self.grid_limits_visible = None
        self.normalize = normalize
        self.colormap = colormap
        # A fresh dict per instance instead of a shared mutable default.
        self.what_kwargs = {} if what_kwargs is None else what_kwargs
        self.grid_before = grid_before
        self.figure_key = figure_key
        self.fig = fig
        self.vshape = vshape
        self.scales = [[self.limits[0][0], self.limits[0][1]],
                       [self.limits[1][0], self.limits[1][1]]]

        self._new_progressbar()

        self.output = widgets.Output()

        def output_changed(*ignore):
            self.widget.new_output = True

        self.output.observe(output_changed, 'outputs')

        self._cleanups = []

        self.progress = widgets.FloatProgress(value=0.0,
                                              min=0.0,
                                              max=1.0,
                                              step=0.01)
        self.progress.layout.width = "95%"
        self.progress.layout.max_width = '500px'
        self.progress.description = "progress"

        self.toolbar = v.Row(pa_1=True, children=[])

        # Vaextended arguments; identity tests avoid invoking a custom
        # ``__eq__`` on the supplied objects.
        required = (model, x_col, y_col, update_stats)
        if any(argument is None for argument in required):
            raise ValueError(
                'The following arguments are required for using plot_widget() with Vaextended: model, x_col, y_col, update_stats\n\nSee docs/README_VAEXTENDED.md for more information'
            )

        self.model = model
        self.x_col = x_col
        self.y_col = y_col
        self.x_label = x_label
        self.y_label = y_label
        self.cache_size = self.model.cache_size()
        self.update_stats = update_stats

        self.backend.create_widget(self.output, self, self.dataset,
                                   self.limits)

        self.widget = PlotTemplate(components={
            'main-widget':
            widgets.VBox([self.backend.widget, self.progress, self.output]),
            'output-widget':
            self.output,
            'toolbar':
            self.toolbar,
            'default_title':
            self.model.plot_title(),
            'main-legend':
            self.model.legend.widgets,
            'legend-control':
            self.model.legend.legend_button
        },
                                   model=show_drawer)
        if grid is None:
            self.update_grid()
        else:
            self.grid = grid

        # checkboxes work because of this
        self.dataset.signal_selection_changed.connect(
            lambda *x: self.update_grid())

        def _on_limits_change(*args):
            self._progressbar.cancel()
            self.update_grid()

        self.backend.observe(_on_limits_change, "limits")
        for attrname in "x y z vx vy vz".split():

            # ``attrname`` is bound as a default so each callback keeps the
            # name it was created for.
            def _on_change(change, attrname=attrname):
                limits_index = {'x': 0, 'y': 1, 'z': 2}.get(attrname)
                if limits_index is not None:
                    self.backend.limits[limits_index] = None
                self.update_grid()

            self.observe(_on_change, attrname)
        # self.update_image() # sometimes bqplot doesn't update the image correcly

    @debounced(0.3, method=True)
    def hide_progress(self):
        """Hide the progress bar by setting its layout visibility to 'hidden'."""
        self.progress.layout.visibility = 'hidden'

    def get_limits(self, limits):
        """Return ``self.dataset.limits`` for the expressions from ``get_binby()``.

        :param limits: Passed through unchanged as the second argument.
        """
        return self.dataset.limits(self.get_binby(), limits)

    def active_selections(self):
        """Return the selections that should currently be plotted.

        ``self.selection`` is turned into a list; ``False`` is replaced by
        ``None`` and ``True`` by ``"default"``. An entry is kept when the
        dataset has a selection of that name or when the entry is
        ``False``/``None``. If nothing is kept, ``[False]`` is returned.
        """
        kept = []
        for entry in _ensure_list(self.selection):
            if entry is False:
                entry = None
            elif entry is True:
                entry = "default"
            if self.dataset.has_selection(entry) or entry in [False, None]:
                kept.append(entry)
        return kept if kept else [False]

    def show(self):
        """Display ``self.widget`` using ``display``."""
        display(self.widget)

    def add_to_toolbar(self, widget):
        """Append ``widget`` to the toolbar's children and send the 'children' state.

        :param widget: the widget to add; it is placed after the existing children.
        """
        self.toolbar.children += [widget]
        # TODO: find out why we need to do this, is this a bug?
        # NOTE(review): presumably the in-place ``+=`` above does not trigger a
        # state sync on its own, hence the explicit send_state -- confirm.
        self.toolbar.send_state('children')

    def _progress(self, v):
        """Set the progress widget's ``value`` to ``v``."""
        progress_widget = self.progress
        progress_widget.value = v

    def _new_progressbar(self):
        """Create a new vaex progress bar and store it on ``self._progressbar``.

        The bar is given a callback (``next``) that makes the progress widget
        visible, lets the IPython kernel (if any) process one iteration, stores
        the reported value on the progress widget and hides the widget again
        once the value reaches 1.  The callback returns ``False`` as soon as
        the stored progress bar has been cancelled.
        """
        def on_progress(fraction):
            with self.output:
                self.progress.layout.visibility = 'visible'
                import IPython
                shell = IPython.get_ipython()
                # no shell is available outside IPython (e.g. when testing)
                if shell is not None:
                    shell.kernel.do_one_iteration()
                self.progress.value = fraction
                if fraction == 1:
                    self.hide_progress()
                return not self._progressbar.cancelled

        self._progressbar = vaex.utils.progressbars(
            False,
            next=on_progress,
            name="bqplot",
        )

    def get_shape(self):
        """Return ``self.shape`` expanded by ``_expand_shape`` to the number of binby expressions."""
        grid_shape = self.shape
        n_dims = len(self.get_binby())
        return _expand_shape(grid_shape, n_dims)

    def get_vshape(self):
        """Return ``self.vshape`` expanded by ``_expand_shape`` to the number of binby expressions."""
        vector_shape = self.vshape
        n_dims = len(self.get_binby())
        return _expand_shape(vector_shape, n_dims)

    @debounced(.5, method=True)
    def update_grid(self):
        """Resolve missing limits and then trigger a grid recomputation.

        Every backend limit (up to ``self.backend.dim``) that is ``None`` is
        requested from the dataset as a delayed computation.  Once all limits
        are known they are stored on ``self.limits`` and on the backend, and
        ``_update_grid`` is called.
        """
        with self.output:
            pending = self.backend.limits[:self.backend.dim]
            expressions = [self.x, self.y, self.z]
            for axis, current in enumerate(pending):
                if current is None:
                    pending[axis] = self.dataset.limits(expressions[axis],
                                                        delay=True)

            @delayed
            def limits_done(resolved):
                with self.output:
                    dim = self.backend.dim
                    self.limits[:dim] = np.array(resolved).tolist()
                    # assign a modified copy rather than mutating the
                    # backend's own limits object
                    backend_limits = copy.deepcopy(self.backend.limits)
                    backend_limits[:dim] = self.limits[:dim]
                    self.backend.limits = backend_limits
                    self._update_grid()

            limits_done(delayed_list(pending))
            self._execute()

    def _update_grid(self):
        """Schedule the computation of the main grid and the vector grids.

        Cancels the current progress bar, creates a new one, and requests from
        ``self.dataset`` the statistic ``self.what`` plus, for every non-empty
        ``vx``/``vy``/``vz`` expression, its binned mean and a combined count.
        With ``delay`` enabled (the hard-coded default) the requests are
        collected as promises; when they resolve the results are stored on
        ``self.grid``, ``self.vgrids`` and ``self.vcount`` and the image is
        redrawn, unless the progress bar of this run was cancelled in the
        meantime.
        """
        with self.output:
            self._progressbar.cancel()
            self._new_progressbar()
            # keep a handle on the bar of *this* run: a later run cancels it,
            # which lets `assign` below detect that its results are stale
            current_pb = self._progressbar
            delay = True
            promises = []
            pb = self._progressbar.add("grid")
            result = self.dataset._stat(binby=self.get_binby(),
                                        what=self.what,
                                        limits=self.limits,
                                        shape=self.get_shape(),
                                        progress=pb,
                                        selection=self.active_selections(),
                                        delay=delay)
            if delay:
                promises.append(result)
            else:
                self.grid = result

            vs = [self.vx, self.vy, self.vz]
            for i, v in enumerate(vs):
                # a missing expression contributes None so that the positions
                # in `promises` always line up with the arguments of `assign`
                result = None
                if v:
                    result = self.dataset.mean(
                        v,
                        binby=self.get_binby(),
                        limits=self.limits,
                        shape=self.get_vshape(),
                        progress=self._progressbar.add("v" + str(i)),
                        selection=self.active_selections(),
                        delay=delay)
                if delay:
                    promises.append(result)
                else:
                    self.vgrids[i] = result

            result = None
            if any(vs):
                expr = "*".join(v for v in vs if v)
                result = self.dataset.count(
                    expr,
                    binby=self.get_binby(),
                    limits=self.limits,
                    shape=self.get_vshape(),
                    progress=self._progressbar.add("vcount"),
                    selection=self.active_selections(),
                    delay=delay)
            if delay:
                promises.append(result)
            else:
                # the count is stored on vcount; it must not overwrite a vgrid
                self.vcount = result

            @delayed
            def assign(grid, vx, vy, vz, vcount):
                with self.output:
                    if not current_pb.cancelled:  # TODO: remote dataset jobs cannot be cancelled
                        self.progress.value = 0
                        self.grid = grid
                        self.vgrids = [vx, vy, vz]
                        self.vcount = vcount
                        self._update_image()

            if delay:
                for promise in promises:
                    if promise:
                        promise.end()
                assign(*promises).end()
                self._execute()
            else:
                self._update_image()

    @debounced(0.05, method=True)
    def _execute(self):
        """Call ``self.dataset.execute()`` inside the ``self.output`` context."""
        with self.output:
            dataset = self.dataset
            dataset.execute()

    @debounced(0.5, method=True)
    def update_image(self):
        """Public entry point (wrapped by ``debounced``) that delegates to ``_update_image``."""
        redraw = self._update_image
        redraw()

    def _update_image(self):
        """Post-process the current grid and hand the resulting image to the backend.

        In order: apply the transformation parsed from ``self.f`` to a copy of
        the grid, optionally spread non-zero cells over neighbouring zero
        cells, overwrite the start of the first row with a fixed value,
        normalise, and finally pass either a colorized image or the last
        normalised grid to ``self.backend.update_image``.

        Raises ``ValueError`` when the colorized grid has more than 4 dimensions.
        """
        with self.output:
            grid = self.get_grid().copy()  # we may modify inplace
            f = _parse_f(self.f)
            # silence numpy divide-by-zero / invalid-value warnings raised while applying f
            with np.errstate(divide='ignore', invalid='ignore'):
                fgrid = f(grid)
            # hard-coded range; since it is not None, self.normalise() below
            # takes its fixed-range branch
            self.grid_limits = [0, 8]

            # extent of the visible region along y (limits[1]) and x (limits[0])
            y_lo, y_hi = self.backend.limits[1]
            x_lo, x_hi = self.backend.limits[0]
            diffy = y_hi - y_lo
            diffx = x_hi - x_lo
            # (initial value; overwritten two lines below)
            new_size = 0
            curr_scale = max(diffx, diffy)
            # block size used for spreading below, capped at 32
            # NOTE(review): no guard against curr_scale == 0 -- confirm the
            # limits can never collapse to a single point
            new_size = int(min(32, self.shape // curr_scale))
            if new_size > 1:
                # (initial values; both are rebound by the loops below)
                row = col = 0
                rows = len(fgrid[0])
                cols = len(fgrid[0][0])
                # read from fgrid, write to the copy, so already-filled cells
                # are not treated as sources themselves
                n_fgrid = copy.deepcopy(fgrid)
                for row in range(rows):
                    for col in range(cols):
                        val = fgrid[0][row][col]
                        if val != 0:
                            # fill the zero cells of the block that starts at
                            # (row, col) and extends up to new_size cells
                            # towards higher indices, clipped to the grid
                            for i in range(row, min(rows, row + new_size)):
                                for j in range(col, min(cols, col + new_size)):
                                    if fgrid[0][i][j] == 0:
                                        n_fgrid[0][i][j] = val
                fgrid = n_fgrid

            # fraction (at most 1) of the y extent covered by self.cache_size
            # NOTE(review): divides by diffy without a zero check -- confirm
            cache_fraction = min(1, self.cache_size / diffy)
            lim = cache_fraction * fgrid.shape[1]
            # overwrite the first int(lim) entries of fgrid[0][0] with 2.4
            # NOTE(review): the meaning of 2.4 is not evident from this code;
            # the count is derived from shape[1] while the index runs along
            # the last axis -- presumably the grid is square, confirm
            for i in range(int(lim)):
                fgrid[0][0][i] = 2.4

            ngrid, fmin, fmax = self.normalise(fgrid)
            if self.backend.wants_colors():
                color_grid = self.colorize(ngrid)
                if len(color_grid.shape) > 3:
                    if len(color_grid.shape) == 4:
                        if color_grid.shape[0] > 1:
                            # several layers along axis 0: combine them with
                            # vaex.image.fade, in reversed order
                            color_grid = vaex.image.fade(color_grid[::-1])
                        else:
                            # a single layer: drop the leading axis
                            color_grid = color_grid[0]
                    else:
                        raise ValueError(
                            "image shape is %r, don't know what to do with that, expected (L, M, N, 3)"
                            % (color_grid.shape, ))
                # rotate by 90 degrees and copy before handing to the backend
                I = np.rot90(color_grid).copy()
                self.backend.update_image(I)
            else:
                # backend does the coloring itself: give it the last normalised grid
                self.backend.update_image(ngrid[-1])

    def get_grid(self):
        """Return the object currently stored on ``self.grid``."""
        current_grid = self.grid
        return current_grid

    def get_vgrids(self):
        """Return the first three entries of ``self.vgrids`` followed by ``self.vcount``."""
        vx_grid = self.vgrids[0]
        vy_grid = self.vgrids[1]
        vz_grid = self.vgrids[2]
        return vx_grid, vy_grid, vz_grid, self.vcount

    def colorize(self, grid):
        """Apply the "colormap" reduction built from ``self.colormap`` to ``grid``.

        :param grid: the grid handed to the reduction.
        :return: whatever the parsed reduction returns for ``grid``.
        """
        to_colors = _parse_reduction("colormap", self.colormap, [])
        return to_colors(grid)

    def normalise(self, grid):
        """Normalise ``grid`` and return ``(grid, vmin, vmax)``.

        When ``self.grid_limits`` is set, a copy of ``grid`` is shifted by
        ``vmin`` and divided by ``vmax - vmin`` (or zeroed when both limits are
        equal); the input is left untouched.  Otherwise the normaliser parsed
        from ``self.normalize`` is applied.
        """
        if self.grid_limits is None:
            normaliser = _parse_n(self.normalize)
            grid, vmin, vmax = normaliser(grid)
            return grid, vmin, vmax

        vmin, vmax = self.grid_limits
        scaled = grid.copy()
        scaled -= vmin
        if vmin == vmax:
            # degenerate range: avoid dividing by zero
            scaled = scaled * 0
        else:
            scaled /= (vmax - vmin)
        return scaled, vmin, vmax

    def get_binby(self):
        """Return the list of expressions to bin by.

        This is ``[x, y]`` when ``self.z`` is falsy, otherwise ``[x, y, z]``
        where anything from the first ``":"`` in ``z`` onwards is dropped.
        """
        if not self.z:
            return [self.x, self.y]

        z_expression = self.z
        if ":" in z_expression:
            z_expression = z_expression.split(":")[0]
        return [self.x, self.y, z_expression]
Ejemplo n.º 27
0
class XarrayInterpolator(Interpolator):
    """Xarray Interpolation

    Attributes
    ----------
    {interpolator_attributes}

    fill_nan: bool
        Default is False. If True, nan values will be filled before interpolation.
    fill_value: float,str
        Default is None. The value that will be used to fill nan values. This can be a number, or "extrapolate", see `scipy.interpn`/`scipy/interp1d`
    kwargs: dict
        Default is {{"bounds_error": False}}. Additional values to pass to xarray's `interp` method.

    """

    dims_supported = ["lat", "lon", "alt", "time"]
    methods_supported = [
        "nearest",
        "linear",
        "bilinear",
        "quadratic",
        "cubic",
        "zero",
        "slinear",
        "next",
        "previous",
        "splinef2d",
    ]

    # defined at instantiation
    method = tl.Unicode(default_value="nearest")
    fill_value = tl.Union([tl.Unicode(), tl.Float()],
                          default_value=None,
                          allow_none=True)
    fill_nan = tl.Bool(False)

    kwargs = tl.Dict({"bounds_error": False})

    def __repr__(self):
        """Return the representation provided by the base interpolator."""
        return super(XarrayInterpolator, self).__repr__()

    @staticmethod
    def _eval_dataarray(eval_coordinates, d):
        """Wrap the eval coordinates of unstacked dim `d` in a DataArray.

        The array is indexed by the (possibly stacked) eval dimension whose
        name contains `d`.
        """
        new_dim = [dd for dd in eval_coordinates.dims if d in dd][0]
        return xr.DataArray(
            eval_coordinates[d].coordinates,
            dims=[new_dim],
            coords=[eval_coordinates.xcoords[new_dim]],
        )

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_interpolate}
        """
        udims_subset = self._filter_udims_supported(udims)

        # confirm that the supported udims exist on the source coordinates
        if not self._dim_in(udims_subset, source_coordinates, unstacked=True):
            return tuple()

        # Cannot handle stacked source dimensions
        for d in source_coordinates.udims:
            if source_coordinates.is_stacked(d):
                return tuple()
        return udims_subset

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data,
                    eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """
        coords = {}
        nn_coords = {}

        for d in udims:
            # Note: This interpolator cannot handle stacked source -- and this is handled in the can_interpolate function
            if source_coordinates[d].size == 1:
                # If the source only has a single coordinate, xarray will automatically throw an error asking for at least 2 coordinates
                # So, we prevent this. Main problem is that this won't respect any tolerances.
                nn_coords[d] = self._eval_dataarray(eval_coordinates, d)
                continue
            if not source_coordinates.is_stacked(
                    d) and eval_coordinates.is_stacked(d):
                coords[d] = self._eval_dataarray(eval_coordinates, d)
            else:
                # TODO: Check dependent coordinates
                coords[d] = eval_coordinates[d].coordinates

        kwargs = self.kwargs.copy()
        kwargs.update({"fill_value": self.fill_value})

        coords["kwargs"] = kwargs

        # "bilinear" is passed on as "linear". Use a local name so that the
        # configured `method` trait of this instance is not overwritten as a
        # side effect of interpolating.
        method = "linear" if self.method == "bilinear" else self.method

        if self.fill_nan:
            for d in source_coordinates.dims:
                # stop as soon as no nan values are left
                if not np.any(np.isnan(source_data)):
                    break
                # use_coordinate=False allows for interpolation when dimension is not monotonically increasing
                source_data = source_data.interpolate_na(method=method,
                                                         dim=d,
                                                         use_coordinate=False)

        if nn_coords:
            source_data = source_data.sel(method="nearest", **nn_coords)

        output_data = source_data.interp(method=method, **coords)

        return output_data.transpose(*eval_coordinates.dims)
Ejemplo n.º 28
0
class CatBoostModel(state.HasState):
    '''The CatBoost algorithm.

    This class provides an interface to the CatBoost algorithm.
    CatBoost is a fast, scalable, high performance Gradient Boosting on
    Decision Trees library, used for ranking, classification, regression and
    other machine learning tasks. For more information please visit
    https://github.com/catboost/catboost

    Example:

    >>> import vaex
    >>> import vaex.ml.catboost
    >>> df = vaex.ml.datasets.load_iris()
    >>> features = ['sepal_width', 'petal_length', 'sepal_length', 'petal_width']
    >>> df_train, df_test = vaex.ml.train_test_split(df)
    >>> params = {
    ...     'leaf_estimation_method': 'Gradient',
    ...     'learning_rate': 0.1,
    ...     'max_depth': 3,
    ...     'bootstrap_type': 'Bernoulli',
    ...     'objective': 'MultiClass',
    ...     'eval_metric': 'MultiClass',
    ...     'subsample': 0.8,
    ...     'random_state': 42,
    ...     'verbose': 0}
    >>> booster = vaex.ml.catboost.CatBoostModel(features=features, num_boost_round=100, params=params)
    >>> booster.fit(df_train, 'class_')
    >>> df_train = booster.transform(df_train)
    >>> df_train.head(3)
    #    sepal_length    sepal_width    petal_length    petal_width    class_  catboost_prediction
    0             5.4            3               4.5            1.5         1  [0.00615039 0.98024259 0.01360702]
    1             4.8            3.4             1.6            0.2         0  [0.99034267 0.00526382 0.0043935 ]
    2             6.9            3.1             4.9            1.5         1  [0.00688241 0.95190908 0.04120851]
    >>> df_test = booster.transform(df_test)
    >>> df_test.head(3)
    #    sepal_length    sepal_width    petal_length    petal_width    class_  catboost_prediction
    0             5.9            3               4.2            1.5         1  [0.00464228 0.98883351 0.00652421]
    1             6.1            3               4.6            1.4         1  [0.00350424 0.9882139  0.00828186]
    2             6.6            2.9             4.6            1.3         1  [0.00325705 0.98891631 0.00782664]
    '''

    features = traitlets.List(
        traitlets.Unicode(),
        help='List of features to use when fitting the CatBoostModel.')
    num_boost_round = traitlets.CInt(default_value=None,
                                     allow_none=True,
                                     help='Number of boosting iterations.')
    params = traitlets.Dict(
        help=
        'A dictionary of parameters to be passed on to the CatBoostModel model.'
    )
    pool_params = traitlets.Dict(
        default_value={},
        help=
        'A dictionary of parameters to be passed to the Pool data object construction'
    )
    prediction_name = traitlets.Unicode(
        default_value='catboost_prediction',
        help='The name of the virtual column housing the predictions.')
    prediction_type = traitlets.Enum(
        values=['Probability', 'Class', 'RawFormulaVal'],
        default_value='Probability',
        help=
        'The form of the predictions. Can be "RawFormulaVal", "Probability" or "Class".'
    )

    def __call__(self, *args):
        '''Predict from feature columns given as separate arrays.

        :param args: one array per feature; they are cast to float64 and
            stacked as the columns of a 2d array.
        :return: the booster's predictions of type `prediction_type`.
        '''
        data2d = np.vstack([arg.astype(np.float64) for arg in args]).T.copy()
        dmatrix = catboost.Pool(data2d, **self.pool_params)
        return self.booster.predict(dmatrix,
                                    prediction_type=self.prediction_type)

    def transform(self, df):
        '''Transform a DataFrame such that it contains the predictions of the CatBoostModel in form of a virtual column.

        :param df: A vaex DataFrame. It should have the same columns as the DataFrame used to train the model.

        :return copy: A shallow copy of the DataFrame that includes the CatBoostModel prediction as a virtual column.
        :rtype: DataFrame
        '''
        df_copy = df.copy()
        lazy_function = df_copy.add_function('catboost_prediction_function',
                                             self)
        expression = lazy_function(*self.features)
        df_copy.add_virtual_column(self.prediction_name,
                                   expression,
                                   unique=False)
        return df_copy

    def fit(self,
            df,
            target,
            evals=None,
            early_stopping_rounds=None,
            verbose_eval=None,
            plot=False,
            **kwargs):
        '''Fit the CatBoostModel model given a DataFrame.
        This method accepts all key word arguments for the catboost.train method.

        :param df: A vaex DataFrame containing the training features.
        :param target: The column name of the target variable.
        :param evals: A list of DataFrames to be evaluated during training.
        This allows user to watch performance on the validation sets.
        The list itself is not modified.
        :param int early_stopping_rounds: Activates early stopping.
        :param bool verbose_eval: Requires at least one item in *evals*.
        If *verbose_eval* is True then the evaluation metric on the validation set is printed at each boosting stage.
        :param bool plot: if True, display an interactive widget in the Jupyter
        notebook of how the train and validation sets score on each boosting iteration.
        '''
        data = df[self.features].values
        target_data = df.evaluate(target)
        dtrain = catboost.Pool(data=data,
                               label=target_data,
                               **self.pool_params)

        # Build the evaluation pools in a new list, so the caller's list of
        # DataFrames is not overwritten with Pool objects.
        eval_pools = None
        if evals is not None:
            eval_pools = [
                catboost.Pool(data=item[self.features].values,
                              label=item.evaluate(target),
                              **self.pool_params) for item in evals
            ]

        # This does the actual training/fitting of the catboost model
        self.booster = catboost.train(
            params=self.params,
            dtrain=dtrain,
            num_boost_round=self.num_boost_round,
            evals=eval_pools,
            early_stopping_rounds=early_stopping_rounds,
            verbose_eval=verbose_eval,
            plot=plot,
            **kwargs)

    def predict(self, df, **kwargs):
        '''Provided a vaex DataFrame, get an in-memory numpy array with the predictions from the CatBoostModel model.
        This method accepts the key word arguments of the predict method from catboost.

        :param df: a vaex DataFrame

        :returns: A in-memory numpy array containing the CatBoostModel predictions.
        :rtype: numpy.array
        '''
        data = df[self.features].values
        dmatrix = catboost.Pool(data, **self.pool_params)
        return self.booster.predict(dmatrix,
                                    prediction_type=self.prediction_type,
                                    **kwargs)

    def state_get(self):
        '''Return the state of this model as a dict.

        :return: a dict with the base64 encoded (ascii) serialized booster
            under 'tree_state' and the state of the parent class under 'substate'.
        '''
        import os  # only needed to build the temporary file path

        # The booster can only be saved to a file, so go through a private
        # temporary directory which is removed again afterwards.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, 'model')
            self.booster.save_model(filename)
            with open(filename, 'rb') as f:
                data = f.read()
        return dict(tree_state=base64.encodebytes(data).decode('ascii'),
                    substate=super(CatBoostModel, self).state_get())

    def state_set(self, state, trusted=True):
        '''Restore the model from a dict produced by `state_get`.

        :param state: dict with the keys 'substate' and 'tree_state'.
        :param trusted: accepted for interface compatibility; not used here.
        '''
        import os  # only needed to build the temporary file path

        super(CatBoostModel, self).state_set(state['substate'])
        data = base64.decodebytes(state['tree_state'].encode('ascii'))
        # The booster can only be loaded from a file, so go through a private
        # temporary directory which is removed again afterwards.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, 'model')
            with open(filename, 'wb') as f:
                f.write(data)
            self.booster = catboost.CatBoost().load_model(fname=filename)
Ejemplo n.º 29
0
class Interpolator(tl.HasTraits):
    """Interpolation Method

    Attributes
    ----------
    {interpolator_attributes}

    """

    # defined by implementing Interpolator class
    methods_supported = tl.List(tl.Unicode())
    dims_supported = tl.List(tl.Unicode())

    # defined at instantiation
    method = tl.Unicode()

    # Next are used for optimizing the interpolation pipeline
    # If -1, it's cost is assume the same as a competing interpolator in the
    # stack, and the determination is made based on the number of DOF before
    # and after each interpolation step.
    # cost_func = tl.CFloat(-1)  # The rough cost FLOPS/DOF to do interpolation
    # cost_setup = tl.CFloat(-1)  # The rough cost FLOPS/DOF to set up the interpolator

    def __init__(self, **kwargs):
        """Initialize the traits, validate `method` and call :meth:`init`.

        Raises
        ------
        InterpolatorException
            If `methods_supported` is not empty and does not contain `method`.
        """

        # Call traitlets constructor
        super(Interpolator, self).__init__(**kwargs)

        # check method (an empty `methods_supported` accepts any method)
        if self.methods_supported and self.method not in self.methods_supported:
            raise InterpolatorException(
                "Method {} is not supported by Interpolator {}".format(
                    self.method, self.name))
        self.init()

    def __repr__(self):
        return "{} ({})".format(self.name, self.method)

    @property
    def name(self):
        """
        Interpolator name

        Returns
        -------
        str
            Name of the interpolator class.
        """
        return str(self.__class__.__name__)

    @property
    def definition(self):
        """
        Interpolator definition

        Returns
        -------
        str
            String name of interpolator.
        """
        return self.name

    def init(self):
        """
        Overwrite this method if a Interpolator needs to do any
        additional initialization after the standard initialization.
        """

    def _filter_udims_supported(self, udims):
        """Return the dims of `udims` that are in `dims_supported`.

        The result is a tuple without duplicates that keeps the order of
        `udims`, so that it does not depend on set iteration order.
        """
        supported = set(self.dims_supported)
        seen = set()
        ordered = []
        for d in udims:
            if d in supported and d not in seen:
                seen.add(d)
                ordered.append(d)
        return tuple(ordered)

    def _dim_in(self, dim, *coords, **kwargs):
        """Verify the dim exists on coordinates

        Parameters
        ----------
        dim : str, list of str
            Dimension or list of dimensions to verify
        *coords :class:`podpac.Coordinates`
            coordinates to evaluate
        unstacked : bool, optional
            True if you want to compare dimensions in unstacked form, otherwise compare dimensions however
            they are defined on the DataSource. Defaults to False.

        Returns
        -------
        Boolean
            True if the dim is in all input coordinates

        Raises
        ------
        ValueError
            If `dim` is not a str, list or tuple.
        """

        unstacked = kwargs.pop("unstacked", False)

        if isinstance(dim, six.string_types):
            dim = [dim]
        elif not isinstance(dim, (list, tuple)):
            raise ValueError(
                "`dim` input must be a str, list of str, or tuple of str")

        for coord in coords:
            available = coord.udims if unstacked else coord.dims
            for d in dim:
                if d not in available:
                    return False

        return True

    def _loop_helper(self, func, interp_dims, udims, source_coordinates,
                     source_data, eval_coordinates, output_data, **kwargs):
        """In cases where the interpolator can only handle a limited number of dimensions, loop over the extra ones

        Parameters
        ----------
        func : callable
            The interpolation function that should be called on the data subset. Should have the following arguments:
            func(udims, source_coordinates, source_data, eval_coordinates, output_data)
        interp_dims: list(str)
            List of source dimensions that will be interpolated. The remaining dimensions are looped over.
        udims: list(str)
           The unstacked coordinates that this interpolator handles
        source_coordinates: podpac.Coordinates
            The coordinates of the source data
        source_data: podpac.UnitsDataArray
            The source data; its dims that are not in `interp_dims` are looped over
        eval_coordinates: podpac.Coordinates
            The user-requested or evaluated coordinates
        output_data: podpac.UnitsDataArray
            Container for the output of the interpolation function
        **kwargs
            Passed through to `func`

        Returns
        -------
        The result of `func` when there is nothing to loop over, otherwise `output_data`
        filled in along the looped dimensions.

        Raises
        ------
        InterpolatorException
            If an output coordinate of a looped dimension is missing from the source data.
        """
        loop_dims = [d for d in source_data.dims if d not in interp_dims]
        if not loop_dims:  # Do the actual interpolation
            return func(udims, source_coordinates, source_data,
                        eval_coordinates, output_data, **kwargs)

        # peel off one looped dimension per recursion level
        dim = loop_dims[0]
        for i in output_data.coords[dim]:
            idx = {dim: i}

            if not i.isin(source_data.coords[dim]):
                # This case should have been properly handled in the interpolation_manager
                raise InterpolatorException("Unexpected interpolation error")

            output_data.loc[idx] = self._loop_helper(
                func, interp_dims, udims,
                source_coordinates.drop(dim), source_data.loc[idx],
                eval_coordinates.drop(dim), output_data.loc[idx], **kwargs)
        return output_data

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_select(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_select}
        """
        if self.method not in Selector.supported_methods:
            return tuple()

        udims_subset = self._filter_udims_supported(udims)
        return udims_subset

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def select_coordinates(self,
                           udims,
                           source_coordinates,
                           eval_coordinates,
                           index_type="numpy"):
        """
        {interpolator_select}
        """
        selector = Selector(method=self.method)
        return selector.select(source_coordinates,
                               eval_coordinates,
                               index_type=index_type)

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def can_interpolate(self, udims, source_coordinates, eval_coordinates):
        """
        {interpolator_can_interpolate}
        """
        return tuple()

    @common_doc(COMMON_INTERPOLATOR_DOCS)
    def interpolate(self, udims, source_coordinates, source_data,
                    eval_coordinates, output_data):
        """
        {interpolator_interpolate}
        """
        raise NotImplementedError
Ejemplo n.º 30
0
class WebVisualizer(ipywidgets.DOMWidget):
    """Open3D Web Visualizer based on WebRTC."""

    # Name of the widget view class in front-end.
    _view_name = traitlets.Unicode('WebVisualizerView').tag(sync=True)

    # Name of the widget model class in front-end.
    _model_name = traitlets.Unicode('WebVisualizerModel').tag(sync=True)

    # Name of the front-end module containing widget view.
    _view_module = traitlets.Unicode('open3d').tag(sync=True)

    # Name of the front-end module containing widget model.
    _model_module = traitlets.Unicode('open3d').tag(sync=True)

    # Version of the front-end module containing widget view.
    # @...@ is configured by cpp/pybind/make_python_package.cmake.
    _view_module_version = traitlets.Unicode(
        '~@PROJECT_VERSION_THREE_NUMBER@').tag(sync=True)
    # Version of the front-end module containing widget model.
    _model_module_version = traitlets.Unicode(
        '~@PROJECT_VERSION_THREE_NUMBER@').tag(sync=True)

    # Widget specific property. Widget properties are defined as traitlets. Any
    # property tagged with `sync=True` is automatically synced to the frontend
    # *any* time it changes in Python. It is synced back to Python from the
    # frontend *any* time the model is touched.
    window_uid = traitlets.Unicode("window_UNDEFINED",
                                   help="Window UID").tag(sync=True)

    # Two-way communication channels.
    pyjs_channel = traitlets.Unicode(
        "Empty pyjs_channel.",
        help="Python->JS message channel.").tag(sync=True)
    jspy_channel = traitlets.Unicode(
        "Empty jspy_channel.",
        help="JS->Python message channel.").tag(sync=True)

    def show(self):
        """Display this widget with IPython."""
        IPython.display.display(self)

    def _call_http_api(self, entry_point, query_string, data):
        """Forward a request to the WebRTC server's HTTP API and return its result."""
        return o3d.visualization.webrtc_server.call_http_api(
            entry_point, query_string, data)

    @traitlets.validate('window_uid')
    def _valid_window_uid(self, proposal):
        """Accept only window uids that start with "window_".

        Raises:
            traitlets.TraitError: if the proposed value has a different prefix.
        """
        value = proposal['value']
        if not value.startswith("window_"):
            raise traitlets.TraitError('window_uid must be "window_xxx".')
        return value

    @traitlets.observe('jspy_channel')
    def _on_jspy_channel(self, change):
        """Handle a message on the JS->Python channel.

        The message is expected to be a JSON object mapping call ids to
        ``{"func": "call_http_api", "args": [entry_point, query_string, data]}``.
        Every call id that has no result yet is executed and its result is
        stored in ``self.result_map``; on success the whole map is written, as
        JSON, to ``pyjs_channel``.  A message that cannot be processed is
        reported with ``print`` and otherwise ignored.
        """
        # self.result_map = {"0": "result0",
        #                    "1": "result1", ...};
        if not hasattr(self, "result_map"):
            self.result_map = dict()

        jspy_message = change["new"]
        try:
            jspy_requests = json.loads(jspy_message)

            for call_id, payload in jspy_requests.items():
                if "func" not in payload or payload["func"] != "call_http_api":
                    raise ValueError(f"Invalid jspy function: {jspy_requests}")
                if "args" not in payload or len(payload["args"]) != 3:
                    raise ValueError(
                        f"Invalid jspy function arguments: {jspy_requests}")

                # Check if already in result.
                if call_id not in self.result_map:
                    json_result = self._call_http_api(payload["args"][0],
                                                      payload["args"][1],
                                                      payload["args"][2])
                    self.result_map[call_id] = json_result
        except Exception:
            # Best effort: a bad message must not break the widget.  Only
            # Exception is caught, so KeyboardInterrupt and SystemExit still
            # propagate.
            print(
                f"jspy_message is not a function call, ignored: {jspy_message}"
            )
        else:
            self.pyjs_channel = json.dumps(self.result_map)