class Class(Schema):
    """Schema for a YAML mapping whose allowed keys are declared as Fields.

    ``load`` turns each mapping entry into a keyword argument (named by the
    field's alias, or its name when there is no alias) and then calls
    ``constructor`` with those keyword arguments.
    """

    @match(basestring, basestring, object, many(Field))
    def __init__(self, name, docs, constructor, *fields):
        # Full form: schema name, documentation text, constructor, fields.
        self.name = name
        self.docs = docs
        self.constructor = constructor
        # Fields indexed by name, kept in declaration order.
        self.fields = OrderedDict((field.name, field) for field in fields)

    @match(basestring, object, many(Field))
    def __init__(self, name, constructor, *fields):
        # Short form without docs: delegate to the full form with empty docs.
        self.__init__(name, "", constructor, *fields)

    @match(MappingNode)
    def load(self, node):
        # Raises SchemaError for an undeclared key, for a missing field whose
        # default is REQUIRED, or when the constructor itself raises
        # SchemaError (re-raised with the mapping's location appended).
        values = {}
        for key_node, value_node in node.value:
            key = key_node.value
            if key not in self.fields:
                raise SchemaError("no such field: %s\n%s" % (key, key_node.start_mark))
            field = self.fields[key]
            values[field.alias or field.name] = field.type.load(value_node)

        # Fields the document left out fall back to their declared default.
        for field in self.fields.values():
            target = field.alias or field.name
            if target in values:
                continue
            if field.default is REQUIRED:
                raise SchemaError("required field '%s' is missing\n%s" % (field.name, node.start_mark))
            values[target] = field.default

        try:
            return self.constructor(**values)
        except SchemaError as e:
            raise SchemaError("%s\n\n%s" % (e, node.start_mark))
class Class(Schema):
    """Schema for a YAML mapping whose allowed keys are declared as Fields.

    ``load`` turns each mapping entry into a keyword argument (named by the
    field's alias, or its name when there is no alias) and calls
    ``constructor`` with them. When no constructor is supplied (the third
    positional argument is already a Field), entries are loaded into an
    ``OrderedDict``. With ``strict=False``, undeclared keys are accepted and
    loaded with ``Any()`` instead of raising SchemaError.
    """

    @match(basestring, basestring, object, many(Field))
    def __init__(self, name, docs, constructor, *fields, **kwargs):
        # Full form: schema name, documentation text, constructor (optional),
        # fields, and the keyword-only option ``strict`` (default True).
        self.name = name
        self.docs = docs
        if isinstance(constructor, Field):
            # The constructor slot actually holds the first field: shift it
            # back into the field list and default to an OrderedDict.
            # NOTE(review): load() calls constructor(**values); on Python 2
            # keyword arguments arrive as a plain dict, so the resulting
            # OrderedDict's key order is not guaranteed to follow field or
            # document order -- confirm whether callers rely on ordering.
            fields = (constructor,) + fields
            constructor = OrderedDict
        elif not callable(constructor):
            raise TypeError("constructor must be callable")
        self.constructor = constructor
        # Fields indexed by name, kept in declaration order.
        self.fields = OrderedDict((field.name, field) for field in fields)
        self.strict = kwargs.pop("strict", True)
        # Any keyword left over is unsupported.
        if kwargs:
            raise TypeError("no such arg(s): %s" % ", ".join(kwargs.keys()))

    @match(basestring, object, many(Field))
    def __init__(self, name, constructor, *fields):
        # Short form without docs: delegate to the full form with empty docs.
        # NOTE(review): this overload neither accepts nor forwards keyword
        # arguments, so ``strict`` appears settable only via the full form --
        # confirm that is intended.
        self.__init__(name, "", constructor, *fields)

    @match(MappingNode)
    def load(self, node):
        # Raises SchemaError for an undeclared key (strict mode only), for a
        # missing field whose default is REQUIRED, or when the constructor
        # raises SchemaError (re-raised with the mapping's location).
        values = {}
        for key_node, value_node in node.value:
            key = key_node.value
            if key in self.fields:
                field = self.fields[key]
            elif self.strict:
                raise SchemaError("no such field: %s\n%s" % (key, key_node.start_mark))
            else:
                # Lenient mode: load the unknown key with an ad-hoc Any field.
                field = Field(key, Any())
            values[field.alias or field.name] = field.type.load(value_node)

        # Missing fields: REQUIRED is an error, OMIT leaves the key out
        # entirely, any other default is passed through as-is.
        for field in self.fields.values():
            target = field.alias or field.name
            if target in values:
                continue
            if field.default is REQUIRED:
                raise SchemaError("required field '%s' is missing\n%s" % (field.name, node.start_mark))
            if field.default is not OMIT:
                values[target] = field.default

        try:
            return self.constructor(**values)
        except SchemaError as e:
            raise SchemaError("%s\n\n%s" % (e, node.start_mark))
class ATest(object):
    """Fixture exercising ``match`` overloading of methods.

    Every overload carries a distinct one-word docstring ("init1", "foo2",
    ...); the doc tests in this file assert that the dispatcher's
    ``__doc__`` contains each of them, so those strings must not change.
    """

    # __init__ overload for a str argument: stores it and records case 1.
    @match(str)
    def __init__(self, x):
        "init1"
        self.x = x
        self.case = 1

    # __init__ overload for an int argument: re-dispatches with str(y), which
    # runs the str overload (setting x and case 1), then overrides case to 2.
    @match(int)
    def __init__(self, y):
        "init2"
        self.__init__(str(y))
        self.case = 2

    # foo overload for (str, int); the leading 1 identifies which overload ran.
    @match(str, int)
    def foo(self, x, y):
        "foo1"
        return 1, x, y

    # foo overload for (int, str); tagged with 2.
    @match(int, str)
    def foo(self, x, y):
        "foo2"
        return 2, x, y

    # foo overload for a single list whose items all match int; tagged with 3.
    @match([many(int)])
    def foo(self, lst):
        "foo3"
        return 3, lst
class Union(Schema):
    """Schema accepting a node when any member schema can load it.

    Members are tried in the order given; the first one whose ``load`` does
    not raise SchemaError wins.
    """

    @match(many(Schema, min=1))
    def __init__(self, *schemas):
        # At least one member schema is required by the match signature.
        self.schemas = schemas

    @property
    def name(self):
        # e.g. "(a|b|c)" built from the members' names.
        member_names = [member.name for member in self.schemas]
        return "(%s)" % "|".join(member_names)

    @match(Node)
    def load(self, node):
        # A SchemaError from one member just means "try the next one".
        for member in self.schemas:
            try:
                return member.load(node)
            except SchemaError:
                continue
        alternatives = "|".join(member.name for member in self.schemas)
        raise SchemaError("expecting one of (%s), got %s\n%s" % (alternatives, node.tag, node.start_mark))
class Union(Schema):
    """Union of several schemas, dispatching each node to exactly one member.

    Unions must be able to discriminate between their schemas. The means to
    discriminate can be somewhat flexible. A discriminator is computed
    according to the following algorithm.

    Logically the discriminator consists of the following components:

    1. The type. This is sufficient for scalar values and sequences, but we
       need more to discriminate maps into distinct types.

    2. For maps, a further discriminator is computed based on a signature
       composed of all required fields of type Constant.

    Ambiguous member combinations are rejected with ValueError when the
    union is constructed; nodes that fit no member raise SchemaError from
    ``load``.
    """

    @match(many(Schema, min=1))
    def __init__(self, *schemas):
        self.schemas = schemas
        # tags: _tag(schema) -> schema, for every non-Class member. Two
        # members with the same tag would be indistinguishable.
        self.tags = {}
        for s in self.schemas:
            if isinstance(s, Class):
                continue
            t = _tag(s)
            if t in self.tags:
                raise ValueError("ambiguous union: %s appears multiple times" % t)
            else:
                self.tags[t] = s
        # signatures: one _sig(...) per Class member; each must discriminate
        # against every previously registered one.
        # constants: field name -> {constant value -> set of Class members
        # that require that field with that Constant value}.
        self.signatures = []
        self.constants = {}
        for s in self.schemas:
            if not isinstance(s, Class):
                continue
            cls_sig = _sig(s)
            for sig in self.signatures:
                if not cls_sig.descriminates(sig):
                    raise ValueError("ambiguous union: %s, %s" % (sig, cls_sig))
            else:
                # for/else: there is no `break`, so this always runs once no
                # conflicting signature raised above.
                self.signatures.append(cls_sig)
            for f in s.fields.values():
                if f.required and isinstance(f.type, Constant):
                    if f.name not in self.constants:
                        self.constants[f.name] = {}
                    if f.type.value not in self.constants[f.name]:
                        self.constants[f.name][f.type.value] = set()
                    self.constants[f.name][f.type.value].add(s)
        # A field name used as a discriminating constant in one Class member
        # may not appear as a non-Constant field in any Class member.
        for s in self.schemas:
            if not isinstance(s, Class):
                continue
            for f in s.fields.values():
                if not isinstance(f.type, Constant) and f.name in self.constants:
                    raise ValueError(
                        "ambiguous union: '%s' both constant and unconstrained" % f.name)
        # A generic "map" member could never be told apart from Class members.
        if self.signatures and "map" in self.tags:
            raise ValueError("ambiguous union: map and %s" % ", ".join(str(s) for s in self.signatures))

    @property
    def name(self):
        return "(%s)" % "|".join(s.name for s in self.schemas)

    @property
    def docname(self):
        return "(%s)" % "|".join(s.docname for s in self.schemas)

    @match(Node)
    def load(self, node):
        t = _tag(node)
        if self.signatures and t == "map":
            # Start from all Class members and narrow by every scalar entry
            # whose key/value pair matches a registered constant.
            candidates = set(s for s in self.schemas if isinstance(s, Class))
            for k, v in node.value:
                if v.tag.endswith(":map") or v.tag.endswith(":seq"):
                    continue
                # NOTE(review): v.value is the raw scalar text; presumably
                # Constant values are strings so they can compare equal here
                # -- confirm for non-string constants.
                if k.value in self.constants and v.value in self.constants[
                        k.value]:
                    candidates.intersection_update(
                        self.constants[k.value][v.value])
            # Exactly one surviving member is required; zero or several is an
            # error.
            if len(candidates) == 1:
                s = candidates.pop()
                return s.load(node)
            else:
                raise SchemaError(
                    "expecting one of (%s), got %s\n%s" % ("|".join(str(s) for s in self.signatures),
                                                           t, node.start_mark))
        # The node's tag matched no member. If its raw value is a string, the
        # user may have meant a constant of that name (e.g. the YAML string
        # "foo" intended as constant "foo"), so try the value itself as a tag.
        # Presumably _tag() of a constant schema is its value -- confirm.
        # NOTE(review): unlike the other errors in this class, the two raises
        # below do not append node.start_mark; `self.tags.keys() + list` also
        # relies on Python 2's list-returning dict.keys().
        if t not in self.tags:
            v = node.value
            # Only strings are tried: `v` may be unhashable (e.g. a list), and
            # the `in` test below would then raise TypeError.
            if (isinstance(v, str) or isinstance(v, unicode)):
                if v not in self.tags:
                    raise SchemaError(
                        "expecting one of (%s), got %s(%s)" % ("|".join(
                            str(s) for s in self.tags.keys() + self.signatures), t, v))
                return self.tags[v].load(node)
            # Not a string, so it cannot name a constant: nothing in the union fits.
            raise SchemaError(
                "expecting one of (%s), got %s" % ("|".join(str(s) for s in self.tags.keys() + self.signatures), t))
        return self.tags[t].load(node)

    @property
    def traversal(self):
        # Chains every member schema's traversal, in member order.
        for s in self.schemas:
            for t in s.traversal:
                yield t
# limitations under the License.

from yaml import ScalarNode, SequenceNode, MappingNode, CollectionNode, Node, compose, compose_all, serialize, \
    serialize_all
from forge.match import choice, match, many
from StringIO import StringIO
from .schema import _scalar2py

# modes
# Mode names for leaf handling; their consumers are not visible in this part
# of the file.
LEAF_AS_NODE = "LEAF_AS_NODE"
LEAF_AS_PYTHON = "LEAF_AS_PYTHON"
LEAF_AS_STRING = "LEAF_AS_STRING"


# Pre-order, depth-first walk of a mapping: the node itself first, then for
# each entry the whole key subtree followed by the whole value subtree.
# NOTE(review): `parents` is accepted but unused, and the recursive calls do
# not pass it along. Scalar keys/values also need a `traversal` overload for
# ScalarNode, which is not visible here -- confirm it exists.
@match(MappingNode, many(Node))
def traversal(node, *parents):
    yield node
    for key, value in node.value:
        for child in (key, value):
            for descendant in traversal(child):
                yield descendant


# Pre-order, depth-first walk of a sequence: the node, then each item's subtree.
@match(SequenceNode)
def traversal(node):
    yield node
    for item in node.value:
        for descendant in traversal(item):
            yield descendant
def test_giant_switch():
    # One large `choice` of `when(pattern, action)` rules. Every rule maps to
    # a distinct Action, so each expectation below pins exactly which rule
    # the compiled matcher selects for a given argument list.
    OBJECT = Action("object")
    BAZ = Action("Baz")
    FOO_TYPE = Action("Foo")
    FOO_VALUE = Action("FOO")
    FOO_BAZ = Action("Foo, Baz")
    FOO_OBJECT = Action("Foo, object")
    OBJECT_THREE = Action("object, 3")
    OBJECT_OBJECT = Action("object, object")
    INTS = Action("ints")
    THREES = Action("threes")
    ONE_TO_FOUR = Action("1, 2, 3, 4")
    LIST_OF_INT = Action("list-of-int")
    LIST_OF_ZERO = Action("list-of-zero")
    TUPLE_OF_INT = Action("tuple-of-int")
    PAIRS = Action("PAIRS")

    frag = choice(
        when(object, OBJECT),
        when(Baz, BAZ),
        when(Foo, FOO_TYPE),
        when(FOO, FOO_VALUE),
        when(one(Foo, Baz), FOO_BAZ),
        when(one(Foo, object), FOO_OBJECT),
        when(one(Bar, 3), FOO_OBJECT),
        when(one(object, 3), OBJECT_THREE),
        when(one(object, object), OBJECT_OBJECT),
        when(one(int, many(int)), INTS),
        when(one(3, many(3)), THREES),
        when(one(1, 2, 3, 4), ONE_TO_FOUR),
        when([many(int)], LIST_OF_INT),
        when([0], LIST_OF_ZERO),
        when((int,), TUPLE_OF_INT),
        when(many(str, int), PAIRS),
    )
    matcher = compile(frag)

    def expect(action, *args):
        # The caller builds the arguments immediately before dispatch.
        assert matcher.apply(*args) == action

    expect(FOO_TYPE, Foo())
    expect(FOO_TYPE, Bar())
    expect(BAZ, Baz())
    expect(FOO_VALUE, FOO)
    expect(FOO_BAZ, Foo(), Baz())
    expect(FOO_OBJECT, Foo(), 3)
    expect(FOO_OBJECT, Bar(), 3)
    expect(OBJECT_OBJECT, object(), object())
    expect(THREES, 3)
    expect(THREES, 3, 3)
    expect(THREES, 3, 3, 3)
    expect(THREES, 3, 3, 3, 3)
    expect(ONE_TO_FOUR, 1, 2, 3, 4)
    expect(INTS, 0, 1, 2, 3)
    expect(LIST_OF_INT, [1, 2, 3, 4] * 100)
    expect(LIST_OF_ZERO, [0])
    expect(TUPLE_OF_INT, (0,))
    expect(PAIRS, "one", 1, "two", 2, "three", 3)

    # An odd-length str/int sequence fits no rule and must raise MatchError.
    try:
        matcher.apply("one", 1, "two", 2, "three")
    except MatchError:
        pass
    else:
        assert False, "expected MatchError"
def test_init_doc():
    # The merged dispatcher docstring must list every overload's own doc.
    parent_doc = ATest.__init__.__doc__
    for i in range(1, 3):
        assert "init%i" % i in parent_doc
    # NOTE(review): ATest's overloads are documented only "init1"/"init2" and
    # the Sub loop below looks for at most "init3", so "init4" seems to be
    # defined nowhere. If the intent is "Sub's overloads don't leak into
    # ATest", "init3" is probably the label meant -- confirm against Sub.
    assert "init4" not in parent_doc
    child_doc = Sub.__init__.__doc__
    for i in range(1, 4):
        assert "init%i" % i in child_doc


def test_method_doc():
    # Same check for the overloaded `foo` method.
    parent_doc = ATest.foo.__doc__
    for i in range(1, 4):
        assert "foo%i" % i in parent_doc
    # NOTE(review): as above, "foo5" looks undefined anywhere; "foo4" (the
    # extra label the Sub loop expects) is likely the intended leak check.
    assert "foo5" not in parent_doc
    child_doc = Sub.foo.__doc__
    for i in range(1, 5):
        assert "foo%i" % i in child_doc


@match(many(int))
def fdsa(*args):
    # Variadic overload: accepts zero or more ints and returns the args tuple.
    return args


def test_fdsa():
    # fdsa echoes its positional arguments back as a tuple, including none.
    for args in [(), (1,), (1, 2)]:
        assert fdsa(*args) == args


# General case for any int: the two preceding Fibonacci numbers.
@match(int)
def fib(n):
    return fib(n - 1) + fib(n - 2)


# Base cases 0 and 1 return n itself.
@match(choice(0, 1))
def fib(n):
    return n