Ejemplo n.º 1
0
class Class(Schema):
    """Schema for a YAML mapping whose keys are a fixed set of named fields.

    Loading validates every key against ``fields``, loads each value with the
    matching field's type, fills in defaults for absent fields and finally
    calls ``constructor`` with the loaded values as keyword arguments.
    """

    # Overload taking an explicit documentation string.
    @match(basestring, basestring, object, many(Field))
    def __init__(self, name, docs, constructor, *fields):
        self.name = name
        self.docs = docs
        self.constructor = constructor
        # Fields keyed by name, in declaration order.
        self.fields = OrderedDict()
        for f in fields:
            self.fields[f.name] = f

    # Overload without docs: delegate to the overload above with empty docs.
    @match(basestring, object, many(Field))
    def __init__(self, name, constructor, *fields):
        self.__init__(name, "", constructor, *fields)

    @match(MappingNode)
    def load(self, node):
        """Load a YAML mapping node and build it with ``self.constructor``.

        Each value is stored under its field's ``alias`` if set, otherwise
        under the field name; that key is the keyword passed to the
        constructor.

        Raises:
            SchemaError: on an unknown key, on a missing required field, or
                when the constructor itself raises SchemaError (re-raised
                with the node's start position appended).
        """
        loaded = {}
        for k, v in node.value:
            key = k.value
            if key not in self.fields:
                raise SchemaError("no such field: %s\n%s" % (key, k.start_mark))
            f = self.fields[key]
            loaded[f.alias or f.name] = f.type.load(v)
        # Fill in anything the document did not mention.
        for f in self.fields.values():
            key = (f.alias or f.name)
            if key not in loaded:
                if f.default is REQUIRED:
                    raise SchemaError("required field '%s' is missing\n%s" % (f.name, node.start_mark))
                else:
                    loaded[key] = f.default
        try:
            return self.constructor(**loaded)
        except SchemaError as e:
            # Attach the mapping's position so the user can locate the error.
            raise SchemaError("%s\n\n%s" % (e, node.start_mark))
Ejemplo n.º 2
0
class Class(Schema):
    """Schema for a YAML mapping of named fields, built by a constructor.

    The loaded field values are passed to ``constructor`` as keyword
    arguments. If the ``constructor`` slot holds a ``Field`` instead, it is
    treated as the first field and ``OrderedDict`` is used as the constructor.

    Keyword arguments:
        strict: (default True) when true, keys not declared in ``fields`` are
            an error; when false they are loaded with ``Field(key, Any())``.
    """
    @match(basestring, basestring, object, many(Field))
    def __init__(self, name, docs, constructor, *fields, **kwargs):
        self.name = name
        self.docs = docs
        if isinstance(constructor, Field):
            # No constructor supplied: the slot holds the first field, and the
            # result is built as an OrderedDict.
            fields = (constructor, ) + fields
            constructor = OrderedDict
        elif not callable(constructor):
            raise TypeError("constructor must be callable")
        self.constructor = constructor
        # Fields keyed by name, in declaration order.
        self.fields = OrderedDict()
        for f in fields:
            self.fields[f.name] = f
        self.strict = kwargs.pop("strict", True)
        if kwargs:
            # sorted() keeps the message stable regardless of dict ordering.
            raise TypeError("no such arg(s): %s" % ", ".join(sorted(kwargs)))

    # Overload without docs: delegate with an empty docs string.
    # NOTE(review): this overload neither accepts nor forwards ``strict``.
    @match(basestring, object, many(Field))
    def __init__(self, name, constructor, *fields):
        self.__init__(name, "", constructor, *fields)

    @match(MappingNode)
    def load(self, node):
        """Load a YAML mapping node and build it with ``self.constructor``.

        Values are keyed by the field's ``alias`` if set, otherwise its name.
        Absent fields get their default, except that ``OMIT`` defaults leave
        the key out entirely.

        Raises:
            SchemaError: on an unknown key (strict mode only), on a missing
                required field, or when the constructor raises SchemaError
                (re-raised with the node's start position appended).
        """
        loaded = {}
        for k, v in node.value:
            key = k.value
            if key in self.fields:
                f = self.fields[key]
            elif self.strict:
                raise SchemaError("no such field: %s\n%s" %
                                  (key, k.start_mark))
            else:
                # Non-strict: accept the unknown key with an unconstrained type.
                f = Field(key, Any())

            loaded[f.alias or f.name] = f.type.load(v)
        for f in self.fields.values():
            key = (f.alias or f.name)
            if key not in loaded:
                if f.default is REQUIRED:
                    raise SchemaError("required field '%s' is missing\n%s" %
                                      (f.name, node.start_mark))
                elif f.default is not OMIT:
                    loaded[key] = f.default
        try:
            return self.constructor(**loaded)
        except SchemaError as e:
            # Attach the mapping's position so the user can locate the error.
            raise SchemaError("%s\n\n%s" % (e, node.start_mark))
Ejemplo n.º 3
0
class ATest(object):
    """Fixture exercising @match overloading of ``__init__`` and ``foo``.

    The short docstrings of each overload ("init1", "foo1", ...) are
    looked up by the docstring tests, so they must not be edited.
    """

    # __init__ overload 1: a single str argument; records case 1.
    @match(str)
    def __init__(self, x):
        "init1"
        self.x = x
        self.case = 1

    # __init__ overload 2: a single int. Re-dispatches through self.__init__
    # with the int converted to str (landing in overload 1), then overrides
    # case to 2.
    @match(int)
    def __init__(self, y):
        "init2"
        self.__init__(str(y))
        self.case = 2

    # foo overloads, selected by argument types/shape:
    #   (str, int) -> 1, (int, str) -> 2, a single list of ints -> 3.
    @match(str, int)
    def foo(self, x, y):
        "foo1"
        return 1, x, y

    @match(int, str)
    def foo(self, x, y):
        "foo2"
        return 2, x, y

    @match([many(int)])
    def foo(self, lst):
        "foo3"
        return 3, lst
Ejemplo n.º 4
0
class Union(Schema):
    """Schema accepting a node that matches any one of several schemas.

    ``load`` tries each schema in declaration order and returns the result
    of the first one that does not raise SchemaError.
    """

    @match(many(Schema, min=1))
    def __init__(self, *schemas):
        self.schemas = schemas

    @property
    def name(self):
        """Alternatives' names joined by '|' and wrapped in parentheses."""
        return "(%s)" % "|".join(s.name for s in self.schemas)

    @match(Node)
    def load(self, node):
        """Load ``node`` with the first schema that accepts it.

        Raises:
            SchemaError: if every alternative rejects the node.
        """
        for s in self.schemas:
            try:
                return s.load(node)
            except SchemaError:
                # This alternative does not fit; try the next one.
                continue
        # self.name already renders as "(a|b|...)".
        raise SchemaError("expecting one of %s, got %s\n%s" % (self.name,
                                                               node.tag, node.start_mark))
Ejemplo n.º 5
0
class Union(Schema):
    """Unions must be able to discriminate between their schemas. The
    means to discriminate can be somewhat flexible. A discriminator is
    computed according to the following algorithm:

    Logically the discriminator consists of the following components:

    1. The type. This is sufficient for scalar values and sequences,
       but we need more to discriminate maps into distinct types.

    2. For maps, a further discriminator is computed based on a
       signature composed of all required fields of type Constant.

    Bookkeeping built by ``__init__``:
        tags: tag -> schema, for every non-``Class`` schema.
        signatures: one signature per ``Class`` schema, pairwise
            discriminating.
        constants: field name -> constant value -> set of ``Class`` schemas
            that require that field to hold that constant.

    Raises ValueError at construction if the schemas cannot be told apart.
    """
    @match(many(Schema, min=1))
    def __init__(self, *schemas):
        self.schemas = schemas

        # Non-Class schemas are discriminated by their tag alone.
        self.tags = {}
        for s in self.schemas:
            if isinstance(s, Class): continue
            t = _tag(s)
            if t in self.tags:
                raise ValueError("ambiguous union: %s appears multiple times" %
                                 t)
            self.tags[t] = s

        # Class schemas must have signatures that discriminate each pair.
        self.signatures = []
        self.constants = {}
        for s in self.schemas:
            if not isinstance(s, Class): continue
            cls_sig = _sig(s)
            for sig in self.signatures:
                if not cls_sig.descriminates(sig):
                    raise ValueError("ambiguous union: %s, %s" %
                                     (sig, cls_sig))
            self.signatures.append(cls_sig)

            # Index required Constant fields: name -> value -> schemas.
            for f in s.fields.values():
                if f.required and isinstance(f.type, Constant):
                    by_value = self.constants.setdefault(f.name, {})
                    by_value.setdefault(f.type.value, set()).add(s)

        # A field name used as a discriminating constant in one schema must
        # not be an unconstrained (non-Constant) field in another.
        for s in self.schemas:
            if not isinstance(s, Class): continue
            for f in s.fields.values():
                if not isinstance(f.type,
                                  Constant) and f.name in self.constants:
                    raise ValueError(
                        "ambiguous union: '%s' both constant and unconstrained"
                        % f.name)

        # A plain map schema would swallow every Class alternative.
        if self.signatures and "map" in self.tags:
            raise ValueError("ambiguous union: map and %s" %
                             ", ".join(str(s) for s in self.signatures))

    @property
    def name(self):
        return "(%s)" % "|".join(s.name for s in self.schemas)

    @property
    def docname(self):
        return "(%s)" % "|".join(s.docname for s in self.schemas)

    @match(Node)
    def load(self, node):
        """Load ``node`` with the one alternative that discriminates it.

        Maps are matched against the Class alternatives by their constant
        fields. Other nodes are dispatched by tag, or, failing that, by a
        string scalar's value (to allow constants named like the value).

        Raises:
            SchemaError: if no single alternative matches.
        """
        t = _tag(node)
        if self.signatures and t == "map":
            candidates = set(s for s in self.schemas if isinstance(s, Class))
            for k, v in node.value:
                # Collection values cannot match a constant; presumably their
                # .value is a list, which is unhashable for the lookup below.
                if v.tag.endswith((":map", ":seq")): continue
                if k.value in self.constants and v.value in self.constants[
                        k.value]:
                    candidates.intersection_update(
                        self.constants[k.value][v.value])
            if len(candidates) == 1:
                return candidates.pop().load(node)
            raise SchemaError(
                "expecting one of (%s), got %s\n%s" %
                ("|".join(str(s)
                          for s in self.signatures), t, node.start_mark))
        # in case this is an union contains constant(s), and t is a 'string' of the value "foo" from the yaml,
        # the user might have intended a constant named "foo".
        if t not in self.tags:
            v = node.value
            # `v` could have an unhashable type (like 'list') which results in a type error while the `in` check
            if isinstance(v, (str, unicode)):
                if v not in self.tags:
                    raise SchemaError(
                        "expecting one of (%s), got %s(%s)" %
                        (self._expected_names(), t, v))
                return self.tags[v].load(node)
            # it doesn't seem like a constant hence not fit within the union
            raise SchemaError(
                "expecting one of (%s), got %s" % (self._expected_names(), t))

        return self.tags[t].load(node)

    def _expected_names(self):
        """Return every tag and signature joined by '|', for error messages."""
        return "|".join(str(s) for s in list(self.tags.keys()) + self.signatures)

    @property
    def traversal(self):
        for s in self.schemas:
            for t in s.traversal:
                yield t
Ejemplo n.º 6
0
# limitations under the License.

from yaml import ScalarNode, SequenceNode, MappingNode, CollectionNode, Node, compose, compose_all, serialize, \
    serialize_all
from forge.match import choice, match, many
from StringIO import StringIO

from .schema import _scalar2py

# modes: each constant's value is its own name, so it prints and compares
# readably. Presumably these select how leaf (scalar) nodes are represented
# (as yaml Nodes, Python values, or strings). TODO confirm against callers.
LEAF_AS_NODE, LEAF_AS_PYTHON, LEAF_AS_STRING = (
    "LEAF_AS_NODE",
    "LEAF_AS_PYTHON",
    "LEAF_AS_STRING",
)


@match(MappingNode, many(Node))
def traversal(node, *parents):
    """Yield a mapping node, then every node beneath it, depth first.

    For each (key, value) pair the key's subtree is produced before the
    value's. ``parents`` is accepted by the signature but not used, and it
    is not forwarded to the recursive calls. Assumes a traversal overload
    for the other node kinds (e.g. scalars) is registered elsewhere.
    """
    yield node
    for key_node, value_node in node.value:
        for child in (key_node, value_node):
            for descendant in traversal(child):
                yield descendant


@match(SequenceNode)
def traversal(node):
    """Yield a sequence node, then each item's subtree in order, depth first."""
    yield node
    items = node.value
    for item in items:
        for descendant in traversal(item):
            yield descendant
Ejemplo n.º 7
0
def test_giant_switch():
    """Compile one large ``choice`` of ``when`` clauses and check that
    ``apply`` routes each argument pattern to the expected action."""

    OBJECT = Action("object")
    BAZ = Action("Baz")
    FOO_TYPE = Action("Foo")
    FOO_VALUE = Action("FOO")
    FOO_BAZ = Action("Foo, Baz")
    FOO_OBJECT = Action("Foo, object")
    OBJECT_THREE = Action("object, 3")
    OBJECT_OBJECT = Action("object, object")
    INTS = Action("ints")
    THREES = Action("threes")
    ONE_TO_FOUR = Action("1, 2, 3, 4")
    LIST_OF_INT = Action("list-of-int")
    LIST_OF_ZERO = Action("list-of-zero")
    TUPLE_OF_INT = Action("tuple-of-int")
    PAIRS = Action("PAIRS")

    frag = choice(
        when(object, OBJECT),
        when(Baz, BAZ),
        when(Foo, FOO_TYPE),
        when(FOO, FOO_VALUE),
        when(one(Foo, Baz), FOO_BAZ),
        when(one(Foo, object), FOO_OBJECT),
        # NOTE(review): reuses FOO_OBJECT; confirm this is intentional.
        when(one(Bar, 3), FOO_OBJECT),
        when(one(object, 3), OBJECT_THREE),
        when(one(object, object), OBJECT_OBJECT),
        when(one(int, many(int)), INTS),
        when(one(3, many(3)), THREES),
        when(one(1, 2, 3, 4), ONE_TO_FOUR),
        when([many(int)], LIST_OF_INT),
        when([0], LIST_OF_ZERO),
        when((int,), TUPLE_OF_INT),
        when(many(str, int), PAIRS)
    )

    n = compile(frag)

    assert n.apply(Foo()) == FOO_TYPE
    assert n.apply(Bar()) == FOO_TYPE
    assert n.apply(Baz()) == BAZ
    assert n.apply(FOO) == FOO_VALUE
    assert n.apply(Foo(), Baz()) == FOO_BAZ
    assert n.apply(Foo(), 3) == FOO_OBJECT
    assert n.apply(Bar(), 3) == FOO_OBJECT
    assert n.apply(object(), object()) == OBJECT_OBJECT
    assert n.apply(3) == THREES
    assert n.apply(3, 3) == THREES
    assert n.apply(3, 3, 3) == THREES
    assert n.apply(3, 3, 3, 3) == THREES
    assert n.apply(1, 2, 3, 4) == ONE_TO_FOUR
    assert n.apply(0, 1, 2, 3) == INTS
    assert n.apply([1, 2, 3, 4]*100) == LIST_OF_INT
    assert n.apply([0]) == LIST_OF_ZERO
    assert n.apply((0,)) == TUPLE_OF_INT
    assert n.apply("one", 1, "two", 2, "three", 3) == PAIRS

    # An unpaired trailing str is expected to match no clause. The failure is
    # raised outside the try body so it is neither swallowed nor stripped
    # under ``python -O`` (unlike ``assert False``).
    try:
        n.apply("one", 1, "two", 2, "three")
    except MatchError:
        pass
    else:
        raise AssertionError("expected MatchError")
Ejemplo n.º 8
0
def test_init_doc():
    # ATest.__init__'s combined docstring mentions init1 and init2 but not
    # init4; Sub.__init__'s mentions init1 through init3.
    base_doc = ATest.__init__.__doc__
    for label in ("init1", "init2"):
        assert label in base_doc
    assert "init4" not in base_doc
    sub_doc = Sub.__init__.__doc__
    for label in ("init1", "init2", "init3"):
        assert label in sub_doc

def test_method_doc():
    # ATest.foo's combined docstring mentions foo1..foo3 but not foo5;
    # Sub.foo's mentions foo1 through foo4.
    base_doc = ATest.foo.__doc__
    for label in ("foo1", "foo2", "foo3"):
        assert label in base_doc
    assert "foo5" not in base_doc
    sub_doc = Sub.foo.__doc__
    for label in ("foo1", "foo2", "foo3", "foo4"):
        assert label in sub_doc

@match(many(int))
def fdsa(*args):
    """Echo the positional arguments back as a tuple.

    ``match(many(int))`` restricts dispatch to calls whose positional
    arguments (zero or more) are all ints.
    """
    return tuple(args)

def test_fdsa():
    # fdsa must return exactly the ints it was given, including none at all.
    for given in [(), (1,), (1, 2)]:
        assert fdsa(*given) == given

@match(int)
def fib(n):
    """Return the n-th Fibonacci number (general case, n >= 2).

    Relies on match() preferring the value pattern ``choice(0, 1)`` below
    over ``int`` for the base cases.

    Raises:
        ValueError: if ``n`` is negative. Previously such input recursed
            without bound until the interpreter's recursion limit.
    """
    if n < 0:
        raise ValueError("fib is undefined for negative n: %d" % n)
    return fib(n-1) + fib(n-2)

@match(choice(0, 1))
def fib(n):
    """Base cases: fib(0) == 0 and fib(1) == 1."""
    return n