Esempio n. 1
0
class Concat(Spec):
    """Concatenate two Specs so that one runs after the other.

    The left Spec runs first and the right Spec follows it. Every Dimension of
    left and right must contain the same axes.

    .. example_spec::

        from scanspec.specs import Line, Concat

        spec = Concat(Line("x", 1, 3, 3), Line("x", 4, 5, 5))
    """

    left: A[
        Spec,
        schema(description=(
            "The left-hand Spec to Concat, midpoints will appear earlier")),
    ]
    right: A[
        Spec,
        schema(description=(
            "The right-hand Spec to Concat, midpoints will appear later")),
    ]

    def axes(self) -> List:
        """Return the shared axes of left and right (they must be equal)."""
        lhs = self.left.axes()
        rhs = self.right.axes()
        assert lhs == rhs, f"axes {lhs} != {rhs}"
        return lhs

    def create_dimensions(self, bounds=True, nested=False) -> List[Dimension]:
        """Squash each side into one Dimension and join them end to end."""
        first, second = [
            squash_dimensions(side.create_dimensions(bounds, nested))
            for side in (self.left, self.right)
        ]
        return [first.concat(second)]
Esempio n. 2
0
class Circle(Region):
    """Mask that keeps points lying inside (or on) a circle in the xy plane.

    .. example_spec::

        from scanspec.specs import Line
        from scanspec.regions import Circle

        grid = Line("y", 1, 3, 10) * ~Line("x", 0, 2, 10)
        spec = grid & Circle("x", "y", 1, 2, 0.9)
    """

    x_axis: A[
        str,
        schema(description="The name matching the x axis of the spec"),
    ]
    y_axis: A[
        str,
        schema(description="The name matching the y axis of the spec"),
    ]
    x_middle: A[
        float,
        schema(description="The central x point of the circle"),
    ]
    y_middle: A[
        float,
        schema(description="The central y point of the circle"),
    ]
    radius: A[
        float,
        schema(description="Radius of the circle", exc_min=0),
    ]

    def axis_sets(self) -> List[Set[str]]:
        return [{self.x_axis, self.y_axis}]

    def mask(self, points: AxesPoints) -> np.ndarray:
        """Return True where a point's distance to the centre is <= radius."""
        dx = points[self.x_axis] - self.x_middle
        dy = points[self.y_axis] - self.y_middle
        # Compare squared distances so no square root is needed
        return dx * dx + dy * dy <= (self.radius * self.radius)
Esempio n. 3
0
    def bounded(
        axis: AAxis,
        lower: A[
            float,
            schema(description="Lower bound of the first point of the line"),
        ],
        upper: A[
            float,
            schema(description="Upper bound of the last point of the line"),
        ],
        num: ANum,
    ) -> "Line":
        """Build a Line from its outer bounds rather than its centre points.

        .. example_spec::

            from scanspec.specs import Line

            spec = Line.bounded("x", 1, 2, 5)
        """
        half = (upper - lower) / num / 2
        first_centre = lower + half
        # With a single point, stop only defines the step size, so it sits half
        # a step past the upper bound. Otherwise it is the last point's centre.
        last_centre = upper + half if num == 1 else upper - half
        return Line(axis, first_centre, last_centre, num)
Esempio n. 4
0
class A:
    """Example field combining several schema annotations with field metadata."""

    # The type is annotated with several schema(...) entries, a type_name and
    # more schema metadata on the field itself.
    # NOTE(review): `description` is given twice; which one wins (and how the
    # Annotated schemas combine with the field metadata) is decided by
    # apischema, not by anything visible here — confirm against its docs.
    a: Annotated[int,
                 schema(max=10),
                 schema(description="type description"),
                 type_name("someInt"),
                 schema(description="field description"), ] = field(
                     metadata=schema(min=0))
Esempio n. 5
0
class Foo:
    """Two ways to attach alias, schema and required metadata to a field."""

    # Metadata attached through dataclasses.field(metadata=...), merged with `|`.
    # `required` is set even though a default exists; presumably this forces
    # the key to be present on deserialization — confirm in apischema docs.
    bar: int = field(
        default=0,
        metadata=alias("foo_bar") | schema(title="foo! bar!", min=0, max=42)
        | required,
    )
    # Same kind of metadata, attached through typing.Annotated instead.
    baz: Annotated[int,
                   alias("foo_baz"),
                   schema(title="foo! baz!", min=0, max=32), required] = 0
Esempio n. 6
0
class CombinationOf(Region):
    """Abstract baseclass for a combination of two regions, left and right"""

    left: A[Region, schema(description="The left-hand Region to combine")]
    right: A[Region, schema(description="The right-hand Region to combine")]

    def axis_sets(self) -> List[Set[str]]:
        """Merge the axis sets of both operands into a list."""
        combined = self.left.axis_sets() + self.right.axis_sets()
        return list(_merge_axis_sets(combined))
Esempio n. 7
0
class Model:
    """Record with length-constrained string fields and optional extras."""

    id: int
    # client_name is limited to 255 characters by its schema
    client_name: str = field(metadata=schema(max_len=255))
    sort_index: float
    # must be before fields with default value
    # (dataclass fields without a default cannot follow ones that have one)
    grecaptcha_response: str = field(metadata=schema(min_len=20, max_len=1000))
    client_phone: Optional[str] = field(default=None, metadata=schema(max_len=255))
    location: Optional[Location] = None
    contractor: Optional[PositiveInt] = None
    upstream_http_referrer: Optional[str] = field(
        default=None, metadata=schema(max_len=1023)
    )
    last_updated: Optional[datetime] = None
    # default_factory avoids sharing one list between instances
    skills: List[Skill] = field(default_factory=list)
Esempio n. 8
0
class Squash(Spec):
    """Squash the Dimensions of the wrapped scan (but not its midpoints) into a
    single linear stack.

    See Also:
        `why-squash-can-change-path`

    .. example_spec::

        from scanspec.specs import Line, Squash

        spec = Squash(Line("y", 1, 2, 3) * Line("x", 0, 1, 4))
    """

    spec: A[Spec, schema(description="The Spec to squash the dimensions of")]
    check_path_changes: ACheckPathChanges = True

    def axes(self) -> List:
        """Axes are those of the wrapped Spec, unchanged."""
        return self.spec.axes()

    def create_dimensions(self, bounds=True, nested=False) -> List[Dimension]:
        # TODO: if we squash we explode the size, can we avoid this?
        inner_dims = self.spec.create_dimensions(bounds, nested)
        # Path changes are only checked when nested and the flag is enabled
        check = nested and self.check_path_changes
        return [squash_dimensions(inner_dims, check)]
Esempio n. 9
0
class Foo:
    bar: int = field(metadata=schema(min=0, max=10))
    baz: int

    @validator
    def not_equal(self):
        """Yield an error message when bar and baz are equal."""
        equal = self.bar == self.baz
        if not equal:
            return
        yield "bar cannot be equal to baz"
Esempio n. 10
0
class Resource:
    """A resource identified by id, carrying an optional list of tags."""

    id: int
    # At most 3 tags, all distinct; defaults to a fresh empty list per instance
    tags: list[Tag] = field(
        default_factory=list,
        metadata=schema(
            description="regroup multiple resources", max_items=3, unique=True
        ),
    )
Esempio n. 11
0
class Config:
    """Configuration with an explicit flag plus pattern-based extra options."""

    active: bool = True
    # presumably collects additional keys matching ^server_ into this mapping
    # (apischema `properties`) — confirm against apischema docs
    server_options: Mapping[str, bool] = field(
        default_factory=dict, metadata=properties(pattern=r"^server_")
    )
    # Key pattern expressed on the key type itself; `properties(...)` presumably
    # takes the pattern from that annotation — confirm
    client_options: Mapping[
        Annotated[str, schema(pattern=r"^client_")], bool  # noqa: F722
    ] = field(default_factory=dict, metadata=properties(...))
    # Bare `properties`: presumably catches any remaining extra keys — confirm
    options: Mapping[str, bool] = field(default_factory=dict, metadata=properties)
Esempio n. 12
0
class Polygon(Region):
    """Mask contains points of axis within an xy polygon

    Uses the even-odd (crossing number) rule. A horizontal ray is cast from
    each point towards +x, and points whose ray crosses an odd number of
    polygon edges are inside.

    .. example_spec::

        from scanspec.specs import Line
        from scanspec.regions import Polygon

        grid = Line("y", 3, 8, 10) * ~Line("x", 1 ,8, 10)
        spec = grid & Polygon("x", "y", [1.0, 6.0, 8.0, 2.0], [4.0, 10.0, 6.0, 1.0])
    """

    x_axis: A[str,
              schema(description="The name matching the x axis of the spec")]
    y_axis: A[str,
              schema(description="The name matching the y axis of the spec")]
    x_verts: A[
        List[float],
        schema(description="The Nx1 x coordinates of the polygons vertices",
               min_len=3), ]
    y_verts: A[
        List[float],
        schema(description="The Nx1 y coordinates of the polygons vertices",
               min_len=3), ]

    def axis_sets(self) -> List[Set[str]]:
        return [{self.x_axis, self.y_axis}]

    def mask(self, points: AxesPoints) -> np.ndarray:
        """Return a boolean array, True for points inside the polygon."""
        x = points[self.x_axis]
        y = points[self.y_axis]
        # Start from the closing edge (last vertex -> first vertex)
        v1x, v1y = self.x_verts[-1], self.y_verts[-1]
        # Boolean, not int8: `~` on an int8 0/1 array gives -1/-2 (both truthy),
        # and integer arrays used as an index select elements by position
        # instead of masking. The other Regions also return booleans.
        mask = np.zeros(len(x), dtype=bool)
        for v2x, v2y in zip(self.x_verts, self.y_verts):
            # skip horizontal edges: a horizontal ray never crosses them, and
            # they would divide by zero below
            if v2y != v1y:
                # points whose y lies in the half-open y span of this edge...
                vmask = ((y < v2y) & (y >= v1y)) | ((y < v1y) & (y >= v2y))
                t = (y - v1y) / (v2y - v1y)
                # ...and which lie strictly to the left of the edge
                vmask &= x < v1x + t * (v2x - v1x)
                mask ^= vmask
            v1x, v1y = v2x, v2y
        return mask
Esempio n. 13
0
class Line(Drawing):
    """Points from start to stop (inclusive) spaced by step."""

    start: float
    stop: float
    step: float = field(default=1, metadata=schema(exc_min=0))

    async def points(self) -> AsyncIterable[float]:
        """Yield start, start + step, ... while the value is <= stop.

        Raises:
            ValueError: if step is not strictly positive. The loop would
                otherwise never terminate when start <= stop.
        """
        if self.step <= 0:
            raise ValueError(f"step must be strictly positive, got {self.step}")
        index = 0
        point = self.start
        while point <= self.stop:
            yield point
            index += 1
            # Recompute from start instead of repeatedly adding step, so
            # floating-point rounding error does not accumulate.
            point = self.start + index * self.step
Esempio n. 14
0
class Product(Spec):
    """Outer product of two Specs, with inner nested inside outer.

    The whole of inner runs once for every point of outer.

    .. example_spec::

        from scanspec.specs import Line

        spec = Line("y", 1, 2, 3) * Line("x", 3, 4, 12)
    """

    outer: A[Spec, schema(description="Will be executed once")]
    inner: A[Spec, schema(description="Will be executed len(outer) times")]

    def axes(self) -> List:
        """Outer axes first, followed by inner axes."""
        outer_axes = self.outer.axes()
        return outer_axes + self.inner.axes()

    def create_dimensions(self, bounds=True, nested=False) -> List[Dimension]:
        # Outer dimensions are always created without bounds. Presumably only
        # the innermost dimensions need them.
        outer_dims = self.outer.create_dimensions(bounds=False, nested=nested)
        # Inner dimensions are always created as nested, with caller's bounds
        inner_dims = self.inner.create_dimensions(bounds, nested=True)
        return outer_dims + inner_dims
Esempio n. 15
0
class Range(Region):
    """Mask contains points of axis >= min and <= max

    >>> r = Range("x", 1, 2)
    >>> r.mask({"x": np.array([0, 1, 2, 3, 4])})
    array([False,  True,  True, False, False])
    """

    axis: A[str,
            schema(description="The name matching the axis to mask in spec")]
    min: A[float,
           schema(description="The minimum inclusive value in the region")]
    # Description previously said "minimum" by mistake
    max: A[float,
           schema(description="The maximum inclusive value in the region")]

    def axis_sets(self) -> List[Set[str]]:
        return [{self.axis}]

    def mask(self, points: AxesPoints) -> np.ndarray:
        """Return a boolean array, True where min <= value <= max on axis."""
        v = points[self.axis]
        # Both ends of the range are inclusive
        mask = np.bitwise_and(v >= self.min, v <= self.max)
        return mask
Esempio n. 16
0
    def duration(
        duration: A[
            float,
            schema(description="The duration of each static point"),
        ],
        num: ANum = 1,
    ) -> "Static":
        """Build a motionless Static whose only axis is DURATION.

        The given duration is used as the value at each of the "num" points.

        .. example_spec::

            from scanspec.specs import Line, Static

            spec = Line("y", 1, 2, 3) + Static.duration(0.1)
        """
        # The DURATION axis carries the duration as its static value
        spec = Static(DURATION, duration, num)
        return spec
Esempio n. 17
0
class Ellipse(Region):
    """Mask that keeps points inside (or on) a possibly rotated xy ellipse.

    .. example_spec::

        from scanspec.specs import Line
        from scanspec.regions import Ellipse

        grid = Line("y", 3, 8, 10) * ~Line("x", 1 ,8, 10)
        spec = grid & Ellipse("x", "y", 5, 5, 2, 3, 75)
    """

    x_axis: A[
        str,
        schema(description="The name matching the x axis of the spec"),
    ]
    y_axis: A[
        str,
        schema(description="The name matching the y axis of the spec"),
    ]
    x_middle: A[
        float,
        schema(description="The central x point of the ellipse"),
    ]
    y_middle: A[
        float,
        schema(description="The central y point of the ellipse"),
    ]
    x_radius: A[
        float,
        schema(description="The radius along the x axis of the ellipse",
               exc_min=0),
    ]
    y_radius: A[
        float,
        schema(description="The radius along the y axis of the ellipse",
               exc_min=0),
    ]
    angle: A[
        float,
        schema(description="The angle of the ellipse (degrees)"),
    ] = 0.0

    def axis_sets(self) -> List[Set[str]]:
        return [{self.x_axis, self.y_axis}]

    def mask(self, points: AxesPoints) -> np.ndarray:
        """Return True where (dx/x_radius)^2 + (dy/y_radius)^2 <= 1."""
        dx = points[self.x_axis] - self.x_middle
        dy = points[self.y_axis] - self.y_middle
        if self.angle != 0:
            # Undo the ellipse's rotation by rotating the points by -angle
            phi = np.radians(-self.angle)
            cos_phi, sin_phi = np.cos(phi), np.sin(phi)
            dx, dy = dx * cos_phi - dy * sin_phi, dx * sin_phi + dy * cos_phi
        return (dx / self.x_radius)**2 + (dy / self.y_radius)**2 <= 1
Esempio n. 18
0
class Static(Spec):
    """A single static point with "axis" held at "value", repeated "num" times.

    Handy for pinning axis=value at every point of another scan.

    .. example_spec::

        from scanspec.specs import Line, Static

        spec = Line("y", 1, 2, 3) + Static("x", 3)
    """

    axis: AAxis
    value: A[float, schema(description="The value at each point")]
    num: ANum = 1

    @alternative_constructor
    def duration(
        duration: A[
            float,
            schema(description="The duration of each static point"),
        ],
        num: ANum = 1,
    ) -> "Static":
        """A static spec with no motion, only a duration repeated "num" times

        .. example_spec::

            from scanspec.specs import Line, Static

            spec = Line("y", 1, 2, 3) + Static.duration(0.1)
        """
        # The DURATION axis carries the duration as its static value
        spec = Static(DURATION, duration, num)
        return spec

    def axes(self) -> List:
        return [self.axis]

    def _repeats_from_indexes(
            self, indexes: np.ndarray) -> Dict[str, np.ndarray]:
        # Every index maps to the same value, whatever the index itself is
        constant = np.full(len(indexes), self.value)
        return {self.axis: constant}

    def create_dimensions(self, bounds=True, nested=False) -> List[Dimension]:
        # ``nested`` is not used by this Spec
        return _dimensions_from_indexes(
            self._repeats_from_indexes,
            self.axes(),
            self.num,
            bounds,
        )
Esempio n. 19
0
class Rectangle(Region):
    """Mask that keeps points inside (or on) a possibly rotated xy rectangle.

    .. example_spec::

        from scanspec.specs import Line
        from scanspec.regions import Rectangle

        grid = Line("y", 1, 3, 10) * ~Line("x", 0, 2, 10)
        spec = grid & Rectangle("x", "y", 0, 1.1, 1.5, 2.1, 30)
    """

    x_axis: A[
        str,
        schema(description="The name matching the x axis of the spec"),
    ]
    y_axis: A[
        str,
        schema(description="The name matching the y axis of the spec"),
    ]
    x_min: A[
        float,
        schema(description="Minimum inclusive x value in the region"),
    ]
    y_min: A[
        float,
        schema(description="Minimum inclusive y value in the region"),
    ]
    x_max: A[
        float,
        schema(description="Maximum inclusive x value in the region"),
    ]
    y_max: A[
        float,
        schema(description="Maximum inclusive y value in the region"),
    ]
    angle: A[
        float,
        schema(description="Clockwise rotation angle of the rectangle"),
    ] = 0.0

    def axis_sets(self) -> List[Set[str]]:
        return [{self.x_axis, self.y_axis}]

    def mask(self, points: AxesPoints) -> np.ndarray:
        """Return True for points inside the (rotated) rectangle, edges included."""
        # Work relative to the (x_min, y_min) corner, which is the rotation pivot
        dx = points[self.x_axis] - self.x_min
        dy = points[self.y_axis] - self.y_min
        if self.angle != 0:
            # Rotate the points by -angle so the rectangle becomes axis-aligned
            phi = np.radians(-self.angle)
            cos_phi, sin_phi = np.cos(phi), np.sin(phi)
            dx, dy = dx * cos_phi - dy * sin_phi, dx * sin_phi + dy * cos_phi
        width = self.x_max - self.x_min
        height = self.y_max - self.y_min
        inside_x = np.bitwise_and(dx >= 0, dx <= width)
        inside_y = np.bitwise_and(dy >= 0, dy <= height)
        return inside_x & inside_y
Esempio n. 20
0
class Snake(Spec):
    """Run the wrapped Spec backwards on every other iteration when it is
    nested inside another Spec. Usually built with the ``~`` operator.

    .. example_spec::

        from scanspec.specs import Line

        spec = Line("y", 1, 3, 3) * ~Line("x", 3, 5, 5)
    """

    spec: A[
        Spec,
        schema(description="The Spec to run in reverse every other iteration"),
    ]

    def axes(self) -> List:
        """Axes are those of the wrapped Spec, unchanged."""
        return self.spec.axes()

    def create_dimensions(self, bounds=True, nested=False) -> List[Dimension]:
        """Create the wrapped Spec's Dimensions and flag each one as snaking."""
        result = self.spec.create_dimensions(bounds, nested)
        # Mutates the Dimension objects in place before returning them
        for dimension in result:
            dimension.snake = True
        return result
Esempio n. 21
0
    def spaced(
        x_axis: A[
            str,
            schema(description="An identifier for what to move for x"),
        ],
        y_axis: A[
            str,
            schema(description="An identifier for what to move for y"),
        ],
        x_start: A[float, schema(description="x centre of the spiral")],
        y_start: A[float, schema(description="y centre of the spiral")],
        radius: A[float, schema(description="radius of the spiral")],
        dr: A[float, schema(description="difference between each ring")],
        rotate: A[
            float,
            schema(description="How much to rotate the angle of the spiral"),
        ] = 0.0,
    ) -> "Spiral":
        """Build a Spiral with equal spacing in "x_axis" and "y_axis" from its
        "radius" and the gap between successive rings, "dr".

        .. example_spec::

            from scanspec.specs import Spiral

            spec = Spiral.spaced("x", "y", 0, 0, 10, 3)
        """
        # Given phi = sqrt(4 * pi * num) and n_rings = phi / (2 * pi):
        # n_rings * 2 * pi = sqrt(4 * pi * num)  =>  num = n_rings^2 * pi
        rings = radius / dr
        point_count = int(rings**2 * np.pi)
        diameter = radius * 2
        return Spiral(
            x_axis,
            y_axis,
            x_start,
            y_start,
            diameter,
            diameter,
            point_count,
            rotate,
        )
Esempio n. 22
0
class Line(Spec):
    """Evenly spaced points along one axis, with the first and last points
    centred on start and stop respectively.

    .. example_spec::

        from scanspec.specs import Line

        spec = Line("x", 1, 2, 5)
    """

    axis: AAxis
    start: A[
        float,
        schema(description="Midpoint of the first point of the line"),
    ]
    stop: A[
        float,
        schema(description="Midpoint of the last point of the line"),
    ]
    num: ANum

    def axes(self) -> List:
        return [self.axis]

    def _line_from_indexes(self, indexes: np.ndarray) -> Dict[str, np.ndarray]:
        span = self.stop - self.start
        # A single point takes the whole span as its width. Otherwise the span
        # covers the num - 1 gaps between the first and last centres.
        step = span if self.num == 1 else span / (self.num - 1)
        # Index 0 is the lower bound of the first point, half a step before start
        lower = self.start - step / 2
        return {self.axis: indexes * step + lower}

    def create_dimensions(self, bounds=True, nested=False) -> List[Dimension]:
        return _dimensions_from_indexes(
            self._line_from_indexes,
            self.axes(),
            self.num,
            bounds,
        )

    @alternative_constructor
    def bounded(
        axis: AAxis,
        lower: A[
            float,
            schema(description="Lower bound of the first point of the line"),
        ],
        upper: A[
            float,
            schema(description="Upper bound of the last point of the line"),
        ],
        num: ANum,
    ) -> "Line":
        """Build a Line from its outer bounds rather than its centre points.

        .. example_spec::

            from scanspec.specs import Line

            spec = Line.bounded("x", 1, 2, 5)
        """
        half = (upper - lower) / num / 2
        first_centre = lower + half
        # With a single point, stop only defines the step size, so it sits half
        # a step past the upper bound. Otherwise it is the last point's centre.
        last_centre = upper + half if num == 1 else upper - half
        return Line(axis, first_centre, last_centre, num)
Esempio n. 23
0
from dataclasses import dataclass, field
from typing import NewType

from pytest import raises

from apischema import ValidationError, deserialize, schema

# A tag is a string of word characters, at least 3 long. The schema is attached
# to the NewType itself, so every field typed as Tag is validated against it.
Tag = NewType("Tag", str)
schema(min_len=3, pattern=r"^\w*$", examples=["available", "EMEA"])(Tag)


@dataclass
class Resource:
    """A resource identified by id, carrying an optional list of tags."""

    id: int
    # At most 3 tags, all distinct; each one is also validated as a Tag
    tags: list[Tag] = field(
        default_factory=list,
        metadata=schema(
            description="regroup multiple resources", max_items=3, unique=True
        ),
    )


# This payload breaks every constraint at once: too many items, a duplicate,
# a tag that fails the pattern ("bad&") and a tag that is too short ("_").
with raises(ValidationError) as err:  # pytest check exception is raised
    deserialize(
        Resource, {"id": 42, "tags": ["tag", "duplicate", "duplicate", "bad&", "_"]}
    )
assert err.value.errors == [
    {"loc": ["tags"], "msg": "item count greater than 3 (maxItems)"},
    {"loc": ["tags"], "msg": "duplicate items (uniqueItems)"},
    {"loc": ["tags", 3], "msg": "not matching '^\\w*$' (pattern)"},
    {"loc": ["tags", 4], "msg": "string length lower than 3 (minLength)"},
Esempio n. 24
0

# Simpler with apischema


class RGB(NamedTuple):
    """A colour stored as separate red, green and blue integer channels."""

    # NOTE(review): channels are presumably 0-255 (they are written as two hex
    # digits when serialized) but no range is enforced here
    red: int
    green: int
    blue: int


# NewType can be used to add schema to conversion source/target
# but Annotated[str, apischema.schema(pattern=r"#[0-9A-Fa-f]{6}")] would have worked too
HexaRGB = NewType("HexaRGB", str)
# pattern is used in JSON schema and in deserialization validation
# NOTE(review): the pattern has no ^...$ anchors, and JSON Schema patterns
# match anywhere in the string, so a value such as "x#aabbcc" may pass
# validation and then be sliced incorrectly — confirm whether that is intended.
apischema.schema(pattern=r"#[0-9A-Fa-f]{6}")(HexaRGB)


@apischema.deserializer  # could be declared as a staticmethod of RGB class
def from_hexa(hexa: HexaRGB) -> RGB:
    """Parse a "#rrggbb" string into an RGB triple."""
    # Characters 1-2, 3-4 and 5-6 hold the red, green and blue hex pairs
    red, green, blue = (int(hexa[i:i + 2], 16) for i in (1, 3, 5))
    return RGB(red, green, blue)


@apischema.serializer  # could be declared as a method/property of RGB class
def to_hexa(rgb: RGB) -> HexaRGB:
    """Format an RGB triple as a lowercase "#rrggbb" string."""
    channels = (rgb.red, rgb.green, rgb.blue)
    digits = "".join(f"{channel:02x}" for channel in channels)
    return HexaRGB("#" + digits)


assert (  # schema is inherited from deserialized type
    apischema.json_schema.deserialization_schema(RGB) ==
    apischema.json_schema.deserialization_schema(HexaRGB) == {
Esempio n. 25
0
class Dataclass:
    """Wraps a nested dataclass plus an optional constrained integer."""

    nested: SimpleDataclass
    # NOTE(review): min=100 presumably applies only when a value is given,
    # not to the None default — confirm against apischema's Optional handling
    opt: Optional[int] = field(default=None, metadata=schema(min=100))
Esempio n. 26
0
import sys
from datetime import date, datetime
from typing import NewType

from apischema import deserializer, schema, serializer

if sys.version_info < (3, 7):
    # Conversions for Python < 3.7, which lacks datetime.fromisoformat. Parsing
    # and formatting use explicit strptime/strftime formats. Each deserializer
    # uses exactly the format its serializer emits, so values round-trip.
    # Serialization drops microseconds and timezone information.
    Datetime = NewType("Datetime", str)
    schema(format="date-time")(Datetime)

    @deserializer
    def to_datetime(s: Datetime) -> datetime:
        # Previously parsed "%Y-%m-%d", which rejects the "date-time" strings
        # produced by from_datetime below.
        return datetime.strptime(s, "%Y-%m-%dT%H:%M:%S")

    @serializer
    def from_datetime(obj: datetime) -> Datetime:
        return Datetime(obj.strftime("%Y-%m-%dT%H:%M:%S"))

    Date = NewType("Date", str)
    schema(format="date")(Date)

    @deserializer
    def to_date(s: Date) -> date:
        # `date` has no strptime classmethod on these Python versions, so
        # parse as a datetime and keep only its date part.
        return datetime.strptime(s, "%Y-%m-%d").date()

    @serializer
    def from_date(obj: date) -> Date:
        return Date(obj.strftime("%Y-%m-%d"))
Esempio n. 27
0
class Spiral(Spec):
    """Archimedean spiral of "x_axis" and "y_axis", centred on
    ("x_start", "y_start") and rotated by "rotate". Produces "num" points in
    a spiral spanning width of "x_range" and height of "y_range"

    .. example_spec::

        from scanspec.specs import Spiral

        spec = Spiral("x", "y", 1, 5, 10, 50, 30)
    """

    x_axis: A[str, schema(description="An identifier for what to move for x")]
    y_axis: A[str, schema(description="An identifier for what to move for y")]
    x_start: A[float, schema(description="x centre of the spiral")]
    y_start: A[float, schema(description="y centre of the spiral")]
    x_range: A[float, schema(description="x width of the spiral")]
    y_range: A[float, schema(description="y width of the spiral")]
    num: ANum
    rotate: A[
        float,
        schema(description="How much to rotate the angle of the spiral"),
    ] = 0.0

    def axes(self) -> List:
        # TODO: reversed from __init__ args, a good idea?
        return [self.y_axis, self.x_axis]

    def _spiral_from_indexes(
            self, indexes: np.ndarray) -> Dict[str, np.ndarray]:
        # Simplest spiral: r = phi. For the spacing between points to match the
        # spacing between rings:
        #   sqrt(area / num) = ring_spacing
        #   sqrt(pi * phi^2 / num) = 2 * pi
        #   phi = sqrt(4 * pi * num)
        phi = np.sqrt(4 * np.pi * indexes)
        # indexes are 0..num inclusive, so the diameter is twice the largest phi
        diameter = 2 * np.sqrt(4 * np.pi * self.num)
        angle = phi + self.rotate
        # Scaling by range / diameter keeps the spiral strictly inside the range
        y_values = self.y_start + (self.y_range / diameter) * phi * np.cos(angle)
        x_values = self.x_start + (self.x_range / diameter) * phi * np.sin(angle)
        return {self.y_axis: y_values, self.x_axis: x_values}

    def create_dimensions(self, bounds=True, nested=False) -> List[Dimension]:
        return _dimensions_from_indexes(
            self._spiral_from_indexes,
            self.axes(),
            self.num,
            bounds,
        )

    @alternative_constructor
    def spaced(
        x_axis: A[
            str,
            schema(description="An identifier for what to move for x"),
        ],
        y_axis: A[
            str,
            schema(description="An identifier for what to move for y"),
        ],
        x_start: A[float, schema(description="x centre of the spiral")],
        y_start: A[float, schema(description="y centre of the spiral")],
        radius: A[float, schema(description="radius of the spiral")],
        dr: A[float, schema(description="difference between each ring")],
        rotate: A[
            float,
            schema(description="How much to rotate the angle of the spiral"),
        ] = 0.0,
    ) -> "Spiral":
        """Build a Spiral with equal spacing in "x_axis" and "y_axis" from its
        "radius" and the gap between successive rings, "dr".

        .. example_spec::

            from scanspec.specs import Spiral

            spec = Spiral.spaced("x", "y", 0, 0, 10, 3)
        """
        # Given phi = sqrt(4 * pi * num) and n_rings = phi / (2 * pi):
        # n_rings * 2 * pi = sqrt(4 * pi * num)  =>  num = n_rings^2 * pi
        rings = radius / dr
        point_count = int(rings**2 * np.pi)
        diameter = radius * 2
        return Spiral(
            x_axis,
            y_axis,
            x_start,
            y_start,
            diameter,
            diameter,
            point_count,
            rotate,
        )
Esempio n. 28
0
def test_int_as_float():
    """An int is accepted where a float is expected and becomes a real float."""
    converted = deserialize(float, 42)
    assert converted == 42.0
    assert type(converted) is float
    # The same conversion still applies when a numeric schema is attached
    assert deserialize(float, 42, schema=schema(min=0)) == 42.0
    with raises(ValidationError):
        deserialize(float, -1.0, schema=schema(min=0))
Esempio n. 29
0
from graphql.utilities import print_schema

from apischema import schema
from apischema.graphql import graphql_schema


# NOTE(review): `Enum` is not imported in the visible code; presumably
# `from enum import Enum` exists elsewhere — confirm.
class MyEnum(Enum):
    """Two-member enum exposed through the GraphQL schema below."""

    FOO = "FOO"
    BAR = "BAR"


def echo(enum: MyEnum) -> MyEnum:
    """Query resolver that returns its enum argument unchanged."""
    return enum


# Build a schema whose only query is `echo`; enum_schemas attaches a
# description to one enum member (MyEnum.FOO) only.
schema_ = graphql_schema(query=[echo],
                         enum_schemas={MyEnum.FOO: schema(description="foo")})
# Expected SDL: the description is printed as a docstring above FOO alone.
schema_str = '''\
type Query {
  echo(enum: MyEnum!): MyEnum!
}

enum MyEnum {
  """foo"""
  FOO
  BAR
}
'''
assert print_schema(schema_) == schema_str
# Execute the query to check that the enum value round-trips.
# NOTE(review): `graphql_sync` is not imported in the visible code; presumably
# `from graphql import graphql_sync` exists elsewhere — confirm.
assert graphql_sync(schema_, "{echo(enum: FOO)}").data == {"echo": "FOO"}
Esempio n. 30
0
def make_entity_class(definition: Definition,
                      support: Support) -> Type[Entity]:
    """
    We can get a set of Definitions by deserializing an ibek
    support module definition YAML file.

    This function creates one Entity-derived dataclass from the given
    Definition. The class is named "<support.module>.<definition.name>" and has:

    - one field per entry in ``definition.args``;
    - a ``type`` field whose Literal type and default are that full name;
    - an ``entity_enabled`` bool field defaulting to True.

    It also registers an apischema deserializer conversion from the new class
    to ``Entity`` before returning the class.

    See :ref:`entities`
    """
    fields: List[Tuple[str, type, Field[Any]]] = []

    # add in each of the arguments
    for arg in definition.args:
        # Work out the field type and any apischema metadata for this arg.
        # Plain args end up with metadata None (no extra metadata).
        metadata: Any = None
        arg_type: type
        if isinstance(arg, ObjectArg):

            # Resolves an id string to a previously created entity. Defined per
            # iteration but captures no loop variables, so late binding is not
            # an issue. NOTE(review): `id_to_entity` is not defined here;
            # presumably a module-level registry of entities keyed by id.
            # `id` shadows the builtin within this helper.
            def lookup_instance(id):
                try:
                    return id_to_entity[id]
                except KeyError:
                    raise ValidationError(
                        f"{id} is not in {list(id_to_entity)}")

            # Deserialize the field from a str via lookup_instance, and tag the
            # JSON schema so the VS Code plugin treats it as an object reference
            metadata = conversion(deserialization=Conversion(
                lookup_instance, str, Entity)) | schema(
                    extra={"vscode_ibek_plugin_type": "type_object"})
            arg_type = Entity
        elif isinstance(arg, IdArg):
            arg_type = str
            metadata = schema(extra={"vscode_ibek_plugin_type": "type_id"})
        else:
            # arg.type is str, int, float, etc.
            # It is resolved to the real builtin by name. NOTE(review): any
            # builtins attribute name is accepted here, not only types.
            arg_type = getattr(builtins, arg.type)
        if arg.description:
            # Attach the description as an Annotated schema description
            arg_type = A[arg_type, desc(arg.description)]
        # Undefined marks "no default", which makes the field required
        if arg.default is Undefined:
            fld = field(metadata=metadata)
        else:
            fld = field(metadata=metadata, default=arg.default)
        fields.append((arg.name, arg_type, fld))

    # put the literal name in as 'type' for this Entity this gives us
    # a unique key for each of the entity types we may instantiate
    full_name = f"{support.module}.{definition.name}"
    fields.append(
        ("type", Literal[full_name], field(default=cast(Any, full_name))))

    # add a field so we can control rendering of the entity without having to delete
    # it
    fields.append(("entity_enabled", bool, field(default=cast(Any, True))))

    # Stored as a class attribute so each generated class can reach its Definition
    namespace = dict(__definition__=definition)

    # make the Entity derived dataclass for this EntityClass, with a reference
    # to the Definition that created it
    entity_cls = make_dataclass(full_name,
                                fields,
                                bases=(Entity, ),
                                namespace=namespace)
    # Register an identity conversion entity_cls -> Entity. Presumably this lets
    # apischema pick this subclass when deserializing an Entity — confirm.
    deserializer(Conversion(identity, source=entity_cls, target=Entity))
    return entity_cls