Пример #1
0
class Object3D(Asset):
  """Asset carrying a 3D pose plus the axis conventions used to orient it.

  Attributes:
    position: 3D vector, defaults to (0, 0, 0).
    quaternion: orientation quaternion, defaults to (1, 0, 0, 0).
    scale: 3D vector, defaults to (1, 1, 1).
    up: label of the axis treated as "up" by `look_at` (default "Y").
    front: label of the axis treated as "front" by `look_at` (default "-Z").
  """

  position = ktl.Vector3D(default_value=(0., 0., 0.))
  quaternion = ktl.Quaternion(default_value=(1., 0., 0., 0.))
  scale = ktl.Vector3D(default_value=(1., 1., 1.))

  # Axis labels are validated case-insensitively against the listed values.
  up = tl.CaselessStrEnum(["X", "Y", "Z", "-X", "-Y", "-Z"], default_value="Y")
  front = tl.CaselessStrEnum(["X", "Y", "Z", "-X", "-Y", "-Z"], default_value="-Z")

  def look_at(self, target):
    """Update `quaternion` so the object is oriented towards `target`.

    Args:
      target: the point to look at; it is converted with `mathutils.Vector`.
    """
    goal = mathutils.Vector(target)
    origin = mathutils.Vector(self.position)
    offset = goal - origin

    # The axis labels are upper-cased before being handed to to_track_quat,
    # because the caseless enum may hold them in any case.
    track_axis = self.front.upper()
    up_axis = self.up.upper()
    self.quaternion = offset.to_track_quat(track_axis, up_axis)
Пример #2
0
class MuonRingFitter(Component):
    """Fit a muon ring with one of several selectable algorithms.

    The algorithm is selected by name through the ``fit_method`` trait; the
    valid names are the keys of ``FIT_METHOD_BY_NAME`` and the first key is
    the default.
    """

    fit_method = traits.CaselessStrEnum(
        list(FIT_METHOD_BY_NAME),
        default_value=list(FIT_METHOD_BY_NAME)[0],
    ).tag(config=True)

    def __call__(self, x, y, img, mask):
        """Run the configured fit and return its result as a container.

        Lets any fit be invoked as
        ``MuonRingFitter(fit_method="name of the fit")(x, y, img, mask)``.
        The four arguments are forwarded unchanged to the fit function.

        Returns
        -------
        MuonRingParameter
            Holds the fitted center, radius, the polar angle and distance of
            the center, and the name of the fit method that was used.
        """
        method_name = self.fit_method
        fit = FIT_METHOD_BY_NAME[method_name]
        radius, center_x, center_y = fit(x, y, img, mask)

        # Polar representation of the fitted ring center.
        center_phi = np.arctan2(center_y, center_x)
        center_distance = np.sqrt(center_x ** 2.0 + center_y ** 2.0)

        result = MuonRingParameter()
        result.ring_center_x = center_x
        result.ring_center_y = center_y
        result.ring_radius = radius
        result.ring_center_phi = center_phi
        result.ring_center_distance = center_distance
        result.ring_fit_method = method_name
        return result
Пример #3
0
class VizHistogramState(VizBaseState):
    """State for a 1D histogram visualisation.

    Holds the binned expression, its limits, the statistic type and the
    resulting grids. Trait changes are wired to the ``signal_slice`` and
    ``signal_regrid`` signals in ``__init__``.
    """

    x_expression = traitlets.Unicode()
    x_slice = traitlets.CInt(None, allow_none=True)
    type = traitlets.CaselessStrEnum(['count', 'min', 'max', 'mean'], default_value='count')
    aux = traitlets.Unicode(None, allow_none=True)
    groupby = traitlets.Unicode(None, allow_none=True)
    groupby_normalize = traitlets.Bool(False, allow_none=True)
    x_min = traitlets.CFloat(None, allow_none=True)
    x_max = traitlets.CFloat(None, allow_none=True)
    # Array-valued traits carry the numpy (de)serializers as trait metadata.
    grid = traitlets.Any().tag(**serialize_numpy)
    grid_sliced = traitlets.Any().tag(**serialize_numpy)
    x_centers = traitlets.Any().tag(**serialize_numpy)
    x_shape = traitlets.CInt(None, allow_none=True)

    def __init__(self, ds, **kwargs):
        """Initialise the state and register the trait observers.

        If both limits are still unset after construction they are
        calculated; otherwise only the bin centers are computed.
        """
        super(VizHistogramState, self).__init__(ds, **kwargs)

        def on_slice_change(change):
            self.signal_slice.emit(self)

        def on_expression_change(change):
            self.calculate_limits()

        def on_limits_change(change):
            self._update_grid()

        self.observe(on_slice_change, ['x_slice'])
        self.observe(on_expression_change, ['x_expression', 'type', 'aux'])
        # 'groupby', 'shape' and 'groupby_normalize' need no recompute here.
        self.observe(on_limits_change, ['x_min', 'x_max', 'shape'])

        if self.x_min is not None or self.x_max is not None:
            self._calculate_centers()
        else:
            self.calculate_limits()

    def bin_parameters(self):
        """Yield one ``(expression, shape, (min, max), slice)`` tuple for x."""
        expression = self.x_expression
        shape = self.x_shape or self.shape  # falls back to the shared shape
        limits = (self.x_min, self.x_max)
        yield expression, shape, limits, self.x_slice

    def state_get(self):
        """Return a dict of all trait values, passed through each trait's
        'serialize' metadata function (``ident`` when none is set)."""
        return {
            name: self.trait_metadata(name, 'serialize', ident)(getattr(self, name))
            for name in self.trait_names()
        }

    def state_set(self, state):
        """Assign every trait present in ``state``, passing each value through
        the trait's 'deserialize' metadata function (``ident`` when none is set)."""
        for name in self.trait_names():
            if name not in state:
                continue
            restore = self.trait_metadata(name, 'deserialize', ident)
            setattr(self, name, restore(state[name]))

    def calculate_limits(self):
        """Recalculate the x limits and request a regrid."""
        self._calculate_limits('x', 'x_expression')
        # TODO: this also runs from the constructor, which is unnecessary work.
        self.signal_regrid.emit(None)

    def limits_changed(self, change):
        """Request a regrid; the ``change`` payload is ignored."""
        # TODO: this also runs from the constructor, which is unnecessary work.
        self.signal_regrid.emit(None)

    @vaex.jupyter.debounced()
    def _update_grid(self):
        """Recompute the bin centers, then request a regrid (debounced)."""
        self._calculate_centers()
        self.signal_regrid.emit(None)
Пример #4
0
class FileOutput(Output):
    """ Output a file to the local filesystem.

    Attributes
    ----------
    format : str
        One of 'pickle', 'geotif' or 'png' (default 'pickle'). Only 'pickle'
        is implemented by `write`.
    outdir : str
        Directory the output file is written into.
    mode : str
        Output mode, 'file' by default.
    path : str or None
        Path of the most recently written file; None until `write` succeeds.
    """

    outdir = tl.Unicode()
    format = tl.CaselessStrEnum(values=['pickle', 'geotif', 'png'],
                                default_value='pickle').tag(attr=True)
    mode = tl.Unicode(default_value="file").tag(attr=True)

    _path = tl.Unicode(allow_none=True, default_value=None)

    def __init__(self, node, name, format=None, outdir=None, mode=None):
        # Forward only the options that were given, so that the trait
        # defaults apply to everything left as None.
        candidates = (('format', format), ('outdir', outdir), ('mode', mode))
        kwargs = {key: value for key, value in candidates if value is not None}
        super(FileOutput, self).__init__(node=node, name=name, **kwargs)

    @property
    def path(self):
        """Path recorded by the last successful `write` call."""
        return self._path

    def write(self, output, coordinates):
        """Write `output` to a file in `outdir` and remember its path.

        The file name is built from the output name, the node hash and the
        hash of `coordinates`.

        Raises
        ------
        NotImplementedError
            If `format` is 'png' or 'geotif'.
        """
        stem = '%s_%s_%s' % (self.name, self.node.hash, coordinates.hash)
        path = os.path.join(self.outdir, stem)

        if self.format in ('png', 'geotif'):
            raise NotImplementedError("format '%s' not yet implemented" %
                                      self.format)

        if self.format == 'pickle':
            path = '%s.pkl' % path
            with open(path, 'wb') as f:
                cPickle.dump(output, f)

        self._path = path
Пример #5
0
class SmtpSpec(MailSpec):
    """Parameters and methods for SMTP(sending emails)."""

    # Optional authentication mechanism; unset (None) by default.
    login = trt.CaselessStrEnum(
        ['login', 'simple'],
        allow_none=True,
        default_value=None,
        help="""Which SMTP mechanism to use to authenticate: [ login | simple | <None> ]. """,
    ).tag(config=True)

    # Free-form extra options for the mail-client libraries.
    kwds = trt.Dict(
        help="""Any key-value pairs passed to the SMTP/IMAP mail-client libraries.""",
    ).tag(config=True)
Пример #6
0
class ImageOutput(Output):
    """Output an image in RAM

    Attributes
    ----------
    format : str
        Image format; 'png' is the only accepted value.
    mode : str
        Output mode, 'image' by default.
    image : bytes or None
        Result of the last `write` call; None before the first call.
    vmax : float
        Forwarded to `get_image` as ``vmax`` (default NaN).
    vmin : float
        Forwarded to `get_image` as ``vmin`` (default NaN).
    """

    format = tl.CaselessStrEnum(values=['png'],
                                default_value='png').tag(attr=True)
    mode = tl.Unicode(default_value="image").tag(attr=True)
    vmin = tl.CFloat(allow_none=True, default_value=np.nan).tag(attr=True)
    vmax = tl.CFloat(allow_none=True, default_value=np.nan).tag(attr=True)
    image = tl.Bytes(allow_none=True, default_value=None)

    def __init__(self,
                 node,
                 name,
                 format=None,
                 mode=None,
                 vmin=None,
                 vmax=None):
        # Forward only the options that were given, so that the trait
        # defaults apply to everything left as None.
        candidates = (
            ('format', format),
            ('mode', mode),
            ('vmin', vmin),
            ('vmax', vmax),
        )
        kwargs = {key: value for key, value in candidates if value is not None}

        super(ImageOutput, self).__init__(node=node, name=name, **kwargs)

    def write(self, output, coordinates):
        """Render `output` with `get_image` and store the result in `image`.

        `coordinates` is accepted for interface compatibility and is unused.
        """
        render_options = dict(format=self.format,
                              vmin=self.vmin,
                              vmax=self.vmax)
        self.image = get_image(output, **render_options)
Пример #7
0
class Mesh(widgets.DOMWidget):
    """Widget model of a mesh; every trait below is synced to the front-end."""

    # --- front-end view/model identification -------------------------------
    _view_name = Unicode(default_value='MeshView').tag(sync=True)
    _view_module = Unicode(default_value='ipyvolume').tag(sync=True)
    _model_name = Unicode(default_value='MeshModel').tag(sync=True)
    _model_module = Unicode(default_value='ipyvolume').tag(sync=True)
    _view_module_version = Unicode(
        default_value=semver_range_frontend).tag(sync=True)
    _model_module_version = Unicode(
        default_value=semver_range_frontend).tag(sync=True)

    # --- geometry ------------------------------------------------------------
    # x, y, z have no default and do not allow None; u and v are optional.
    x = Array(default_value=None).tag(
        sync=True, **array_sequence_serialization)
    y = Array(default_value=None).tag(
        sync=True, **array_sequence_serialization)
    z = Array(default_value=None).tag(
        sync=True, **array_sequence_serialization)
    u = Array(default_value=None, allow_none=True).tag(
        sync=True, **array_sequence_serialization)
    v = Array(default_value=None, allow_none=True).tag(
        sync=True, **array_sequence_serialization)
    triangles = Array(default_value=None, allow_none=True).tag(
        sync=True, **array_serialization)
    lines = Array(default_value=None, allow_none=True).tag(
        sync=True, **array_serialization)

    # --- appearance ----------------------------------------------------------
    # A texture may be a media stream, a string, a list of strings, an image
    # or a list of images; the alternatives are tried in this order.
    texture = traitlets.Union([
        traitlets.Instance(ipywebrtc.MediaStream),
        Unicode(),
        traitlets.List(trait=Unicode, default_value=[], allow_none=True),
        Image(default_value=None, allow_none=True),
        traitlets.List(Image(default_value=None, allow_none=True)),
    ]).tag(sync=True, **texture_serialization)

    sequence_index = Integer(default_value=0).tag(sync=True)
    color = Array(default_value="red", allow_none=True).tag(
        sync=True, **color_serialization)

    # --- visibility ----------------------------------------------------------
    visible = traitlets.CBool(default_value=True).tag(sync=True)
    visible_lines = traitlets.CBool(default_value=True).tag(sync=True)
    visible_faces = traitlets.CBool(default_value=True).tag(sync=True)

    side = traitlets.CaselessStrEnum(
        values=['front', 'back', 'both'],
        default_value='both',
    ).tag(sync=True)
Пример #8
0
class VizBase2dState(VizBaseState):
    """Base state for 2D visualisations binned over an x and a y expression.

    Trait changes are wired to the ``signal_slice`` and ``signal_regrid``
    signals in ``__init__``.
    """

    x_expression = traitlets.Unicode()
    y_expression = traitlets.Unicode()
    x_slice = traitlets.CInt(None, allow_none=True)
    y_slice = traitlets.CInt(None, allow_none=True)
    type = traitlets.CaselessStrEnum(['count', 'min', 'max', 'mean'],
                                     default_value='count')
    aux = traitlets.Unicode(None, allow_none=True)
    groupby = traitlets.Unicode(None, allow_none=True)
    x_shape = traitlets.CInt(None, allow_none=True)
    y_shape = traitlets.CInt(None, allow_none=True)

    x_min = traitlets.CFloat()
    x_max = traitlets.CFloat()
    y_min = traitlets.CFloat()
    y_max = traitlets.CFloat()

    def __init__(self, ds, **kwargs):
        """Initialise the state, register observers and calculate the limits."""
        super(VizBase2dState, self).__init__(ds, **kwargs)

        def on_expression_change(change):
            self.calculate_limits()

        def on_slice_change(change):
            self.signal_slice.emit(self)

        self.observe(on_expression_change,
                     ['x_expression', 'y_expression', 'type', 'aux'])
        self.observe(on_slice_change, ['x_slice', 'y_slice'])
        # 'groupby', 'shape' and 'groupby_normalize' need no recompute here.
        self.observe(self.limits_changed, ['x_min', 'x_max', 'y_min', 'y_max'])
        self.calculate_limits()

    def bin_parameters(self):
        """Yield ``(expression, shape, (min, max), slice)`` for x, then for y."""
        for axis in ('x', 'y'):
            expression = getattr(self, axis + '_expression')
            # A per-axis shape wins; otherwise fall back to the shared shape.
            shape = getattr(self, axis + '_shape') or self.shape
            limits = (getattr(self, axis + '_min'),
                      getattr(self, axis + '_max'))
            yield expression, shape, limits, getattr(self, axis + '_slice')

    def calculate_limits(self):
        """Recalculate the limits of both axes and request a regrid."""
        for axis in ('x', 'y'):
            self._calculate_limits(axis, axis + '_expression')
        self.signal_regrid.emit(self)

    def limits_changed(self, change):
        """Observer for the limit traits: recompute centers, then regrid."""
        self._calculate_centers()
        self.signal_regrid.emit(self)
Пример #9
0
class WCSRaw(DataSource):
    """
    Access data from a WCS source.

    Attributes
    ----------
    source : str
        WCS server url
    layer : str
        layer name (required)
    version : str
        WCS version, passed through to all requests (default '1.0.0')
    format : str
        Data format, passed through to the GetCoverage requests (default 'geotiff')
    crs : str
        coordinate reference system, passed through to the GetCoverage requests (default 'EPSG:4326')
    interpolation : str
        Interpolation, passed through to the GetCoverage requests.
    max_size : int
        maximum request size, optional.
        If provided, the coordinates will be tiled into multiple requests.
    wcs_kwargs : dict
        Additional query parameters sent to the WCS server with every GetCoverage request.
    allow_mock_client : bool
        Default is False. If True, a mock client will be used to make WCS requests. This allows returns
        from servers with only partial WCS implementations.
    username : str
        Username for servers that require authentication
    password : str
        Password for servers that require authentication

    See Also
    --------
    WCS : WCS datasource with podpac interpolation.
    """

    source = tl.Unicode().tag(attr=True, required=True)
    layer = tl.Unicode().tag(attr=True, required=True)
    version = tl.Unicode(default_value="1.0.0").tag(attr=True)
    interpolation = InterpolationTrait(default_value=None,
                                       allow_none=True).tag(attr=True)
    allow_mock_client = tl.Bool(False).tag(attr=True)
    username = tl.Unicode(allow_none=True)
    password = tl.Unicode(allow_none=True)

    format = tl.CaselessStrEnum(["geotiff", "geotiff_byte"],
                                default_value="geotiff")
    crs = tl.Unicode(default_value="EPSG:4326")
    max_size = tl.Long(default_value=None, allow_none=True)
    wcs_kwargs = tl.Dict(
        help="Additional query parameters sent to the WCS server")

    _repr_keys = ["source", "layer"]

    _requested_coordinates = tl.Instance(Coordinates, allow_none=True)
    _evaluated_coordinates = tl.Instance(Coordinates)
    coordinate_index_type = "slice"

    @property
    def auth(self):
        """Authentication object built from username/password, or None if either is unset."""
        if self.username and self.password:
            return owslib_util.Authentication(username=self.username,
                                              password=self.password)
        return None

    @cached_property
    def client(self):
        """WCS client for `source`.

        Falls back to a `MockWCSClient` when the OWSLib client cannot be created and
        `allow_mock_client` is True; otherwise the original exception propagates.
        """
        try:
            return owslib_wcs.WebCoverageService(self.source,
                                                 version=self.version,
                                                 auth=self.auth)
        except Exception as e:
            if self.allow_mock_client:
                logger.warning(
                    "The OWSLIB Client could not be used. Server endpoint likely does not implement GetCapabilities "
                    "requests. Using Mock client instead. Error was {}".format(
                        e))
                return MockWCSClient(source=self.source,
                                     version=self.version,
                                     auth=self.auth)
            # bare raise keeps the original traceback intact
            raise

    def get_coordinates(self):
        """
        Get the full WCS grid.

        Returns
        -------
        Coordinates
            lat/lon (and time, when the layer advertises time positions) coordinates of the layer.

        Raises
        ------
        NotImplementedError
            If the layer metadata has time limits.
        """

        metadata = self.client.contents[self.layer]

        # coordinates: default to the WGS84 bounding box, but prefer a bounding box
        # whose native SRS matches the configured crs
        bbox = metadata.boundingBoxWGS84
        crs = "EPSG:4326"
        logger.debug("WCS available boundingboxes: {}".format(
            metadata.boundingboxes))
        for bboxes in metadata.boundingboxes:
            if bboxes["nativeSrs"] == self.crs:
                bbox = bboxes["bbox"]
                crs = self.crs
                break

        low = metadata.grid.lowlimits
        high = metadata.grid.highlimits
        xsize = int(high[0]) - int(low[0])
        ysize = int(high[1]) - int(low[1])

        # Based on https://www.ctps.org/geoserver/web/wicket/bookmarkable/org.geoserver.wcs.web.demo.WCSRequestBuilder;jsessionid=9E2AA99F95410C694D05BA609F25527C?0
        # The above link points to a geoserver implementation, which is the reference implementation.
        # WCS version 1.0.0 always has order lon/lat while version 1.1.1 actually follows the CRS
        if self.version == "1.0.0":
            rbbox = {
                "lat": [bbox[1], bbox[3], ysize],
                "lon": [bbox[0], bbox[2], xsize]
            }
        else:
            rbbox = resolve_bbox_order(bbox, crs, (xsize, ysize))

        coords = []
        coords.append(
            UniformCoordinates1d(rbbox["lat"][0],
                                 rbbox["lat"][1],
                                 size=rbbox["lat"][2],
                                 name="lat"))
        coords.append(
            UniformCoordinates1d(rbbox["lon"][0],
                                 rbbox["lon"][1],
                                 size=rbbox["lon"][2],
                                 name="lon"))

        if metadata.timepositions:
            coords.append(
                ArrayCoordinates1d(metadata.timepositions, name="time"))

        if metadata.timelimits:
            raise NotImplementedError("TODO")

        return Coordinates(coords, crs=crs)

    def _eval(self, coordinates, output=None, _selector=None):
        """Evaluates this node using the supplied coordinates.

        This method intercepts the DataSource._eval method in order to use the requested coordinates directly when
        they are a uniform grid.

        Parameters
        ----------
        coordinates : :class:`podpac.Coordinates`
            {requested_coordinates}

            An exception is raised if the requested coordinates are missing dimensions in the DataSource.
            Extra dimensions in the requested coordinates are dropped.
        output : :class:`podpac.UnitsDataArray`, optional
            {eval_output}
        _selector: callable(coordinates, request_coordinates)
            {eval_selector}

        Returns
        -------
        {eval_return}

        Raises
        ------
        ValueError
            Cannot evaluate these coordinates
        NotImplementedError
            If the mock client is in use and the requested lat/lon coordinates are not uniform.
        """
        # The mock client cannot figure out the real coordinates, so just duplicate the requested coordinates
        if isinstance(self.client, MockWCSClient):
            if not coordinates["lat"].is_uniform or not coordinates[
                    "lon"].is_uniform:
                raise NotImplementedError(
                    "When using the Mock WCS client, the requested coordinates need to be uniform."
                )
            self.set_trait("_coordinates", coordinates)
            self.set_trait("crs", coordinates.crs)

        # remove extra dimensions
        extra = [
            c.name for c in coordinates.values()
            if (isinstance(c, Coordinates1d)
                and c.name not in self.coordinates.udims)
            or (isinstance(c, StackedCoordinates)
                and all(dim not in self.coordinates.udims for dim in c.dims))
        ]
        coordinates = coordinates.drop(extra)

        # the datasource does do this, but we need to do it here to correctly select the correct case
        if self.coordinates.crs.lower() != coordinates.crs.lower():
            coordinates = coordinates.transform(self.coordinates.crs)

        # for a uniform grid, use the requested coordinates (the WCS server will interpolate)
        if (("lat" in coordinates.dims and "lon" in coordinates.dims) and
            (coordinates["lat"].is_uniform or coordinates["lat"].size == 1) and
            (coordinates["lon"].is_uniform or coordinates["lon"].size == 1)):

            def selector(rsc, coordinates, index_type=None):
                return coordinates, None

            return super()._eval(coordinates,
                                 output=output,
                                 _selector=selector)

        # for uniform stacked, unstack to use the requested coordinates (the WCS server will interpolate)
        if (("lat" in coordinates.udims and coordinates.is_stacked("lat")) and
            ("lon" in coordinates.udims and coordinates.is_stacked("lon")) and
            (coordinates["lat"].is_uniform or coordinates["lat"].size == 1) and
            (coordinates["lon"].is_uniform or coordinates["lon"].size == 1)):

            def selector(rsc, coordinates, index_type=None):
                unstacked = coordinates.unstack()
                unstacked = unstacked.drop(
                    "alt", ignore_missing=True)  # if lat_lon_alt
                return unstacked, None

            udata = super()._eval(coordinates, output=None, _selector=selector)
            data = udata.data.diagonal()  # get just the stacked data
            if output is None:
                output = self.create_output_array(coordinates, data=data)
            else:
                output.data[:] = data
            return output

        # otherwise, pass-through (podpac will select and interpolate)
        return super()._eval(coordinates, output=output, _selector=_selector)

    def _get_data(self, coordinates, coordinates_index):
        """{get_data}"""

        # transpose the coordinates to match the response data
        if "time" in coordinates:
            coordinates = coordinates.transpose("time", "lat", "lon")
        else:
            coordinates = coordinates.transpose("lat", "lon")

        # determine the chunk size (if applicable): walk the dimensions in order and
        # give each one as many elements as still fit into max_size
        if self.max_size is not None:
            shape = []
            s = 1
            for n in coordinates.shape:
                r = self.max_size // s
                if r == 0:
                    shape.append(1)
                elif r < n:
                    shape.append(r)
                else:
                    shape.append(n)
                s *= n
            shape = tuple(shape)
        else:
            shape = coordinates.shape

        # request each chunk and composite the data
        output = self.create_output_array(coordinates)
        for chunk, slc in coordinates.iterchunks(shape, return_slices=True):
            output[slc] = self._get_chunk(chunk)

        return output

    def _get_chunk(self, coordinates):
        """Request one chunk from the WCS server and return it as a float array.

        The bounding box is padded by half a step on each side (except for size-1
        dimensions), the GetCoverage response is checked for a service exception,
        and the GeoTIFF payload is read with rasterio and flipped as needed.

        Raises
        ------
        WCSError
            If the server returns a service exception or the response cannot be read as a GeoTIFF.
        NotImplementedError
            If more than one time coordinate is requested.
        """
        if coordinates["lon"].size == 1:
            w = coordinates["lon"].coordinates[0]
            e = coordinates["lon"].coordinates[0]
        else:
            w = coordinates["lon"].start - coordinates["lon"].step / 2.0
            e = coordinates["lon"].stop + coordinates["lon"].step / 2.0

        if coordinates["lat"].size == 1:
            s = coordinates["lat"].coordinates[0]
            n = coordinates["lat"].coordinates[0]
        else:
            s = coordinates["lat"].start - coordinates["lat"].step / 2.0
            n = coordinates["lat"].stop + coordinates["lat"].step / 2.0

        width = coordinates["lon"].size
        height = coordinates["lat"].size

        kwargs = self.wcs_kwargs.copy()

        if "time" in coordinates:
            kwargs["time"] = coordinates["time"].coordinates.astype(
                str).tolist()

        if isinstance(self.interpolation, str):
            kwargs["interpolation"] = self.interpolation

        logger.info(
            "WCS GetCoverage (source=%s, layer=%s, bbox=%s, shape=%s, time=%s)"
            % (self.source, self.layer, (w, n, e, s),
               (width, height), kwargs.get("time")))

        # NOTE(review): the result is currently unused (see the commented-out axis
        # check below); the call is kept because it may raise on an invalid CRS.
        crs = pyproj.CRS(coordinates.crs)
        bbox = (min(w, e), min(s, n), max(e, w), max(n, s))
        # Based on the spec I need the following line, but
        # all my tests on other servers suggests I don't need this...
        # if crs.axis_info[0].direction == "north":
        #     bbox = (min(s, n), min(w, e), max(n, s), max(e, w))

        response = self.client.getCoverage(identifier=self.layer,
                                           bbox=bbox,
                                           width=width,
                                           height=height,
                                           crs=self.crs,
                                           format=self.format,
                                           version=self.version,
                                           **kwargs)
        content = response.read()

        # check for errors
        xml = bs4.BeautifulSoup(content, "lxml")
        error = xml.find("serviceexception")
        if error:
            raise WCSError(error.text)

        # get data using rasterio; the dataset must be read while the in-memory
        # file is still open, and is closed again as soon as the data is extracted
        with rasterio.MemoryFile() as mf:
            mf.write(content)
            try:
                dataset = mf.open(driver="GTiff")
            except rasterio.RasterioIOError:
                raise WCSError("Could not read file with contents:", content)

            with dataset:
                if "time" in coordinates and coordinates["time"].size > 1:
                    # this should be easy to do, I'm just not sure how the data comes back.
                    # is each time in a different band?
                    raise NotImplementedError("TODO")

                data = dataset.read().astype(float).squeeze()

        # Need to fix the order of the data in the case of multiple bands
        if len(data.shape) == 3:
            data = data.transpose((1, 2, 0))

        # Need to fix the data order. The request and response order is always the same in WCS, but not in PODPAC
        if n > s:  # By default it returns the data upside down, so this is backwards
            data = data[::-1]
        if e < w:
            data = data[:, ::-1]

        return data

    @classmethod
    def get_layers(cls, source=None):
        """Return the list of layer names advertised by the WCS server at `source`.

        NOTE(review): when `source` is None this falls back to `cls.source`, which on
        this class is the trait declaration rather than a URL string -- confirm that
        callers always pass `source` or that subclasses override the attribute.
        """
        if source is None:
            source = cls.source
        client = owslib_wcs.WebCoverageService(source)
        return list(client.contents)
Пример #10
0
class GroupReduce(Algorithm):
    """
    Group a time-dependent source node and then compute a statistic for each result.
    
    Attributes
    ----------
    custom_reduce_fn : function
        required if reduce_fn is 'custom'.
    groupby : str
        datetime sub-accessor. Currently 'dayofyear' is the enabled option.
    reduce_fn : str
        builtin xarray groupby reduce function, or 'custom'.
    source : podpac.Node
        Source node
    coordinates_source : podpac.Node
        Node used to find the available coordinates; defaults to `source`.
    """

    source = tl.Instance(Node)
    coordinates_source = tl.Instance(Node, allow_none=True)

    # see https://github.com/pydata/xarray/blob/eeb109d9181c84dfb93356c5f14045d839ee64cb/xarray/core/accessors.py#L61
    groupby = tl.CaselessStrEnum(['dayofyear'])  # could add season, month, etc

    reduce_fn = tl.CaselessStrEnum([
        'all', 'any', 'count', 'max', 'mean', 'median', 'min', 'prod', 'std',
        'sum', 'var', 'custom'
    ])
    custom_reduce_fn = tl.Any()

    _source_coordinates = tl.Instance(Coordinates)

    @tl.default('coordinates_source')
    def _default_coordinates_source(self):
        return self.source

    def _get_source_coordinates(self, requested_coordinates):
        """Build the coordinates at which the source must be evaluated.

        The result combines the requested lat/lon with every available time whose
        `groupby` value (e.g. day of year) also occurs in the requested times.

        Raises
        ------
        ValueError
            If the coordinates source does not provide exactly one set of
            coordinates, or those coordinates have no 'time' dimension.
        """
        # get available time coordinates
        # TODO do these two checks during node initialization
        available_coordinates = self.coordinates_source.find_coordinates()
        if len(available_coordinates) != 1:
            raise ValueError(
                "Cannot evaluate this node; too many available coordinates")
        avail_coords = available_coordinates[0]
        if 'time' not in avail_coords.udims:
            raise ValueError(
                "GroupReduce coordinates source node must be time-dependent")

        # intersect grouped time coordinates using groupby DatetimeAccessor
        avail_time = xr.DataArray(avail_coords.coords['time'])
        eval_time = xr.DataArray(requested_coordinates.coords['time'])
        N = getattr(avail_time.dt, self.groupby)
        E = getattr(eval_time.dt, self.groupby)
        native_time_mask = np.in1d(N, E)

        # use requested spatial coordinates and filtered available times
        coords = Coordinates(time=avail_time.data[native_time_mask],
                             lat=requested_coordinates['lat'],
                             lon=requested_coordinates['lon'],
                             order=('time', 'lat', 'lon'))

        return coords

    @common_doc(COMMON_DOC)
    @node_eval
    def eval(self, coordinates, output=None):
        """Evaluates this nodes using the supplied coordinates. 
        
        Parameters
        ----------
        coordinates : podpac.Coordinates
            {requested_coordinates}
        output : podpac.UnitsDataArray, optional
            {eval_output}
        
        Returns
        -------
        {eval_return}
        
        Raises
        ------
        ValueError
            If source is not time-dependent (required by this node).
        """

        self._source_coordinates = self._get_source_coordinates(coordinates)

        if output is None:
            output = self.create_output_array(coordinates)

        source_output = self.source.eval(self._source_coordinates)

        # group
        grouped = source_output.groupby('time.%s' % self.groupby)

        # reduce
        # Compare by value: identity (`is`) between strings is an implementation
        # detail and would silently skip the custom branch.
        if self.reduce_fn == 'custom':
            out = grouped.apply(self.custom_reduce_fn, 'time')
        else:
            # standard, e.g. grouped.median('time')
            out = getattr(grouped, self.reduce_fn)('time')

        # map each requested time onto the reduced value of its group
        eval_time = xr.DataArray(coordinates.coords['time'])
        E = getattr(eval_time.dt, self.groupby)
        out = out.sel(**{self.groupby: E}).rename({self.groupby: 'time'})
        output[:] = out.transpose(*output.dims).data

        return output

    def base_ref(self):
        """
        Default pipeline node reference/name in pipeline node definitions
        
        NOTE(review): this is a plain method and it reads `self.source.base_ref`
        without calling it -- confirm whether `base_ref` is meant to be a property.
        
        Returns
        -------
        str
            Default pipeline node reference/name in pipeline node definitions
        """
        return '%s.%s.%s' % (self.source.base_ref, self.groupby,
                             self.reduce_fn)
Пример #11
0
class ResampleReduce(UnaryAlgorithm):
    """
    Resample a time-dependent source node along time and reduce each bin with a statistic.

    Attributes
    ----------
    custom_reduce_fn : function
        required if reduce_fn is 'custom'.
    resample : str
        Resampling rule; passed as ``time=`` to the ``resample`` method of the source output.
    reduce_fn : str
        One of ``_REDUCE_FUNCTIONS``: the name of a reduce method of the resampled data, or 'custom'.
    source : podpac.Node
        Source node
    coordinates_source : podpac.Node
        Defaults to `source`.
    """

    _repr_keys = ["source", "resample", "reduce_fn"]
    coordinates_source = NodeTrait(allow_none=True).tag(attr=True)

    # see https://github.com/pydata/xarray/blob/eeb109d9181c84dfb93356c5f14045d839ee64cb/xarray/core/accessors.py#L61
    resample = tl.Unicode().tag(attr=True)
    reduce_fn = tl.CaselessStrEnum(_REDUCE_FUNCTIONS).tag(attr=True)
    custom_reduce_fn = tl.Any(allow_none=True, default_value=None).tag(attr=True)

    _source_coordinates = tl.Instance(Coordinates)

    @tl.default("coordinates_source")
    def _default_coordinates_source(self):
        return self.source

    @common_doc(COMMON_DOC)
    def _eval(self, coordinates, output=None, _selector=None):
        """Evaluates this nodes using the supplied coordinates.

        Parameters
        ----------
        coordinates : podpac.Coordinates
            {requested_coordinates}
        output : podpac.UnitsDataArray, optional
            {eval_output}
        _selector: callable(coordinates, request_coordinates)
            {eval_selector}

        Returns
        -------
        {eval_return}

        Raises
        ------
        ValueError
            If source is not time-dependent (required by this node).
        """

        source_output = self.source.eval(coordinates, _selector=_selector)

        # bin the source output along time
        resampled = source_output.resample(time=self.resample)

        # reduce each bin
        if self.reduce_fn != "custom":
            # standard, e.g. resampled.median()
            reduced = getattr(resampled, self.reduce_fn)()
        else:
            reduced = resampled.reduce(self.custom_reduce_fn)

        # fill the caller's array in place when one was supplied
        if output is not None:
            output.data[:] = reduced.data[:]
            return output

        result = podpac.UnitsDataArray(reduced)
        result.attrs = source_output.attrs
        return result

    @property
    def base_ref(self):
        """
        Default node reference/name in node definitions

        Returns
        -------
        str
            ``<source base_ref>.<resample>.<reduce_fn>``
        """
        parts = (self.source.base_ref, self.resample, self.reduce_fn)
        return "%s.%s.%s" % parts
Пример #12
0
class WCSBase(DataSource):
    """
    Access data from a WCS source.

    Attributes
    ----------
    source : str
        WCS server url
    layer : str
        layer name (required)
    version : str
        WCS version, passed through to all requests (default '1.0.0')
    format : str
        Data format, passed through to the GetCoverage requests (default 'geotiff')
    crs : str
        coordinate reference system, passed through to the GetCoverage requests (default 'EPSG:4326')
    interpolation : str
        Interpolation, passed through to the GetCoverage requests.
    max_size : int
        maximum request size, optional.
        If provided, the coordinates will be tiled into multiple requests.
    """

    source = tl.Unicode().tag(attr=True)
    layer = tl.Unicode().tag(attr=True)
    version = tl.Unicode(default_value="1.0.0").tag(attr=True)
    interpolation = tl.Unicode(default_value=None,
                               allow_none=True).tag(attr=True)

    format = tl.CaselessStrEnum(["geotiff", "geotiff_byte"],
                                default_value="geotiff")
    crs = tl.Unicode(default_value="EPSG:4326")
    max_size = tl.Long(default_value=None, allow_none=True)

    _repr_keys = ["source", "layer"]

    _requested_coordinates = tl.Instance(Coordinates, allow_none=True)
    _evaluated_coordinates = tl.Instance(Coordinates)

    @cached_property
    def client(self):
        """OWSLib WebCoverageService client for ``source`` at the configured ``version``."""
        return owslib_wcs.WebCoverageService(self.source, version=self.version)

    def get_coordinates(self):
        """
        Get the full WCS grid.

        Builds uniform lat/lon coordinates from the layer's WGS84 bounding box
        and grid limits, and appends a time dimension when the layer metadata
        lists time positions.

        Raises
        ------
        NotImplementedError
            If the layer metadata defines time limits.
        """

        metadata = self.client.contents[self.layer]

        # TODO select correct boundingbox by crs

        # coordinates
        w, s, e, n = metadata.boundingBoxWGS84
        low = metadata.grid.lowlimits
        high = metadata.grid.highlimits
        xsize = int(high[0]) - int(low[0])
        ysize = int(high[1]) - int(low[1])

        coords = [
            UniformCoordinates1d(s, n, size=ysize, name="lat"),
            UniformCoordinates1d(w, e, size=xsize, name="lon"),
        ]

        if metadata.timepositions:
            coords.append(
                ArrayCoordinates1d(metadata.timepositions, name="time"))

        if metadata.timelimits:
            raise NotImplementedError("TODO")

        return Coordinates(coords, crs=self.crs)

    def _eval(self, coordinates, output=None, _selector=None):
        """Evaluates this node using the supplied coordinates.

        This method intercepts the DataSource._eval method in order to use the requested coordinates directly when
        they are a uniform grid.

        Parameters
        ----------
        coordinates : :class:`podpac.Coordinates`
            {requested_coordinates}

            An exception is raised if the requested coordinates are missing dimensions in the DataSource.
            Extra dimensions in the requested coordinates are dropped.
        output : :class:`podpac.UnitsDataArray`, optional
            {eval_output}
        _selector: callable(coordinates, request_coordinates)
            {eval_selector}

        Returns
        -------
        {eval_return}

        Raises
        ------
        ValueError
            Cannot evaluate these coordinates
        """

        # the datasource does do this, but we need to do it here to correctly select the correct case
        if self.coordinates.crs.lower() != coordinates.crs.lower():
            coordinates = coordinates.transform(self.coordinates.crs)

        # for a uniform grid, use the requested coordinates (the WCS server will interpolate)
        if (("lat" in coordinates.dims and "lon" in coordinates.dims) and
            (coordinates["lat"].is_uniform or coordinates["lat"].size == 1) and
            (coordinates["lon"].is_uniform or coordinates["lon"].size == 1)):

            # selector that hands the requested coordinates straight through
            def selector(rsc, coordinates, index_type=None):
                return coordinates, None

            return super()._eval(coordinates,
                                 output=output,
                                 _selector=selector)

        # for uniform stacked, unstack to use the requested coordinates (the WCS server will interpolate)
        if (("lat" in coordinates.udims and coordinates.is_stacked("lat")) and
            ("lon" in coordinates.udims and coordinates.is_stacked("lon")) and
            (coordinates["lat"].is_uniform or coordinates["lat"].size == 1) and
            (coordinates["lon"].is_uniform or coordinates["lon"].size == 1)):

            # selector that requests the full unstacked lat/lon grid
            def selector(rsc, coordinates, index_type=None):
                unstacked = coordinates.unstack()
                unstacked = unstacked.drop(
                    "alt", ignore_missing=True)  # if lat_lon_alt
                return unstacked, None

            udata = super()._eval(coordinates, output=None, _selector=selector)
            data = udata.data.diagonal()  # get just the stacked data
            if output is None:
                output = self.create_output_array(coordinates, data=data)
            else:
                output.data[:] = data
            return output

        # otherwise, pass-through (podpac will select and interpolate)
        return super()._eval(coordinates, output=output, _selector=_selector)

    def _get_data(self, coordinates, coordinates_index):
        """{get_data}"""

        # transpose the coordinates to match the response data
        if "time" in coordinates:
            coordinates = coordinates.transpose("time", "lat", "lon")
        else:
            coordinates = coordinates.transpose("lat", "lon")

        # determine the chunk size (if applicable)
        if self.max_size is not None:
            shape = []
            s = 1  # running product of the dimension sizes already visited
            for n in coordinates.shape:
                r = self.max_size // s
                if r == 0:
                    shape.append(1)
                elif r < n:
                    shape.append(r)
                else:
                    shape.append(n)
                s *= n
            shape = tuple(shape)
        else:
            shape = coordinates.shape

        # request each chunk and composite the data
        output = self.create_output_array(coordinates)
        for chunk, slc in coordinates.iterchunks(shape, return_slices=True):
            output[slc] = self._get_chunk(chunk)

        return output

    def _get_chunk(self, coordinates):
        """Issue one GetCoverage request and return band 1 of the response as floats.

        Parameters
        ----------
        coordinates : Coordinates
            Chunk coordinates; must contain "lat" and "lon", may contain "time".

        Returns
        -------
        array
            Band 1 of the returned GeoTIFF, cast to float.

        Raises
        ------
        WCSError
            If the response contains a service exception element.
        NotImplementedError
            If more than one time coordinate is requested.
        """

        # bounding box: a single point is sent as-is, otherwise the extent is
        # padded by half a step on each side
        if coordinates["lon"].size == 1:
            w = coordinates["lon"].coordinates[0]
            e = coordinates["lon"].coordinates[0]
        else:
            w = coordinates["lon"].start - coordinates["lon"].step / 2.0
            e = coordinates["lon"].stop + coordinates["lon"].step / 2.0

        if coordinates["lat"].size == 1:
            s = coordinates["lat"].coordinates[0]
            n = coordinates["lat"].coordinates[0]
        else:
            s = coordinates["lat"].start - coordinates["lat"].step / 2.0
            n = coordinates["lat"].stop + coordinates["lat"].step / 2.0

        width = coordinates["lon"].size
        height = coordinates["lat"].size

        kwargs = {}

        if "time" in coordinates:
            kwargs["time"] = coordinates["time"].coordinates.astype(
                str).tolist()

        if isinstance(self.interpolation, str):
            kwargs["interpolation"] = self.interpolation

        logger.info(
            "WCS GetCoverage (source=%s, layer=%s, bbox=%s, shape=%s)",
            self.source, self.layer, (w, n, e, s), (width, height))

        response = self.client.getCoverage(identifier=self.layer,
                                           bbox=(w, n, e, s),
                                           width=width,
                                           height=height,
                                           crs=self.crs,
                                           format=self.format,
                                           version=self.version,
                                           **kwargs)
        content = response.read()

        # check for errors
        xml = bs4.BeautifulSoup(content, "lxml")
        error = xml.find("serviceexception")
        if error:
            raise WCSError(error.text)

        # get data using rasterio; the dataset must be read (and closed) while
        # the in-memory file backing it is still open
        with rasterio.MemoryFile() as mf:
            mf.write(content)
            with mf.open(driver="GTiff") as dataset:
                if "time" in coordinates and coordinates["time"].size > 1:
                    # this should be easy to do, I'm just not sure how the data comes back.
                    # is each time in a different band?
                    raise NotImplementedError("TODO")

                data = dataset.read(1).astype(float)

        return data

    @classmethod
    def get_layers(cls, source=None):
        """List the layer identifiers advertised by a WCS server.

        Parameters
        ----------
        source : str, optional
            WCS server url. Falls back to the class-level ``source`` attribute.

        Returns
        -------
        list
            Keys of the service's ``contents``.
        """
        if source is None:
            # NOTE(review): on this class ``source`` is declared as a traitlets
            # descriptor, so this fallback presumably only yields a url when a
            # subclass overrides ``source`` with a plain string -- confirm.
            source = cls.source
        client = owslib_wcs.WebCoverageService(source)
        return list(client.contents)
Пример #13
0
class Shader(traitlets.HasTraits):
    """
    Creates a shader from source

    Parameters
    ----------

    source : str
        This can either be a string containing a full source of a shader,
        an absolute path to a source file or a filename of a shader
        residing in the ./shaders/ directory.

    """

    # Handle of the compiled GL shader object; None until compile() succeeds.
    _shader = None
    source = traitlets.Any()
    shader_name = traitlets.CUnicode()
    info = traitlets.CUnicode()
    shader_type = traitlets.CaselessStrEnum(("vertex", "fragment", "geometry"))
    blend_func = traitlets.Tuple(GLValue(),
                                 GLValue(),
                                 default_value=("src alpha", "dst alpha"))
    blend_equation = GLValue("func add")
    depth_test = GLValue("always")

    use_separate_blend = traitlets.Bool(False)
    blend_equation_separate = traitlets.Tuple(GLValue(),
                                              GLValue(),
                                              default_value=("none", "none"))
    blend_func_separate = traitlets.Tuple(
        GLValue(),
        GLValue(),
        GLValue(),
        GLValue(),
        default_value=("none", "none", "none", "none"),
    )

    def _get_source(self, source):
        """Resolve ``source`` into the full shader text.

        A value containing ";" is returned unchanged as literal shader code.
        Otherwise ``source`` is taken as one filename or a tuple/list of
        filenames; the common include files are prepended and the contents of
        all files are joined with blank lines.

        Raises
        ------
        YTInvalidShaderType
            If a file is found neither at the given path nor in the
            ``shaders`` directory next to this module.
        """
        if ";" in source:
            # This is probably safe, right?  Enh, probably.
            return source
        # What this does is concatenate multiple (if available) source files.
        # This gets around GLSL's composition issues, which means we can have
        # functions that get called at each step in a ray tracing process, for
        # instance, that can still share ray tracing code between multiple
        # files.
        if not isinstance(source, (tuple, list)):
            source = (source, )
        source = (
            "header.inc.glsl",
            "known_uniforms.inc.glsl",
        ) + tuple(source)
        full_source = []
        for fn in source:
            # an existing path is used as given; anything else is looked up in
            # the bundled shaders directory
            if os.path.isfile(fn):
                sh_directory = ""
            else:
                sh_directory = os.path.join(os.path.dirname(__file__),
                                            "shaders")
            fn = os.path.join(sh_directory, fn)
            if not os.path.isfile(fn):
                raise YTInvalidShaderType(fn)
            # close the handle deterministically instead of leaking it
            with open(fn, "r") as shader_file:
                full_source.append(shader_file.read())
        return "\n\n".join(full_source)

    def _enable_null_shader(self):
        """Compile the fallback source registered for this shader type."""
        source = _NULL_SOURCES[self.shader_type]
        self.compile(source=source)

    def compile(self, source=None, parameters=None):
        """Compile the shader and store its GL handle in ``_shader``.

        Parameters
        ----------
        source : str or tuple or list, optional
            Shader source or filename(s); defaults to ``self.source``.
        parameters : optional
            Not supported; must be None.

        Raises
        ------
        RuntimeError
            If no source is available, or if GL reports a compile failure
            (the message is the shader info log).
        NotImplementedError
            If ``parameters`` is given.
        """
        if source is None:
            source = self.source
            if source is None:
                raise RuntimeError
        if parameters is not None:
            raise NotImplementedError
        source = self._get_source(source)
        shader_type_enum = getattr(GL, f"GL_{self.shader_type.upper()}_SHADER")
        shader = GL.glCreateShader(shader_type_enum)
        # We could do templating here if we wanted.
        # Kept on the instance so a failed compile can be reported line by line.
        self.shader_source = source
        GL.glShaderSource(shader, source)
        GL.glCompileShader(shader)
        result = GL.glGetShaderiv(shader, GL.GL_COMPILE_STATUS)
        if not (result):
            raise RuntimeError(GL.glGetShaderInfoLog(shader))
        self._shader = shader

    def setup_blend(self):
        """Enable blending and depth testing using this shader's settings."""
        GL.glEnable(GL.GL_BLEND)
        if self.use_separate_blend:
            GL.glBlendEquationSeparate(*self.blend_equation_separate)
            GL.glBlendFuncSeparate(*self.blend_func_separate)
        else:
            GL.glBlendEquation(self.blend_equation)
            GL.glBlendFunc(*self.blend_func)
        GL.glEnable(GL.GL_DEPTH_TEST)
        GL.glDepthFunc(self.depth_test)

    @property
    def shader(self):
        """GL shader handle, compiled on first access.

        If compilation raises RuntimeError, the error and the numbered source
        lines (when available) are printed and the null shader for this
        shader type is compiled instead.
        """
        if self._shader is None:
            try:
                self.compile()
            except RuntimeError as exc:
                print(exc)
                # compile() can fail before it has stored shader_source (for
                # example when no source is set), so do not assume it exists
                failed_source = getattr(self, "shader_source", None)
                if failed_source is not None:
                    for line_num, line in enumerate(
                            failed_source.split("\n"), start=1):
                        print(f"{line_num:05}: {line}")
                self._enable_null_shader()
        return self._shader

    def delete_shader(self):
        """Delete the GL shader object, if one exists, and reset the handle."""
        # GL.glDeleteShader is also checked against None before it is called
        if None not in (self._shader, GL.glDeleteShader):
            GL.glDeleteShader(self._shader)
            self._shader = None

    def __del__(self):
        # This is not guaranteed to be called
        self.delete_shader()