Beispiel #1
0
def ensure_union_type(
    schema: s_schema.Schema,
    types: Iterable[s_types.Type],
    *,
    opaque: bool = False,
    module: Optional[str] = None,
) -> Tuple[s_schema.Schema, s_types.Type, bool]:
    """Return a type representing the union of *types*.

    Any input that is itself a union is replaced by its member types, and the
    resulting set is passed through ``minimize_class_set_by_most_generic``.

    * If exactly one type remains and *opaque* is false, that type is
      returned unchanged.
    * If the members are non-object (scalar) types, they are folded pairwise
      into their common implicitly castable type.
    * Otherwise the members are object types and
      ``s_objtypes.get_or_create_union_type`` supplies the union.

    Returns:
        A ``(schema, union_type, created)`` triple.  ``created`` is taken
        from ``get_or_create_union_type``; every other path reports ``False``.

    Raises:
        The error built by ``_union_error`` when scalar and object types are
        mixed, or when two scalar members have no common implicitly castable
        type.
    """
    from edb.schema import objtypes as s_objtypes
    from edb.schema import types as s_types

    # Expand nested unions into their member types.
    member_types: Set[s_types.Type] = set()
    for typ in types:
        nested = typ.get_union_of(schema)
        if nested:
            member_types.update(nested.objects(schema))
        else:
            member_types.add(typ)

    components_list = minimize_class_set_by_most_generic(schema, member_types)

    if not opaque and len(components_list) == 1:
        (sole_type,) = components_list
        return schema, sole_type, False

    # A union may hold only scalars or only object types: record which kinds
    # have appeared and fail as soon as both are present.
    kinds_seen: Set[bool] = set()
    for component in components_list:
        kinds_seen.add(bool(component.is_object_type()))
        if len(kinds_seen) > 1:
            raise _union_error(schema, components_list)

    if False in kinds_seen:
        # Scalar members: reduce to a single common castable type.
        uniontype: s_types.Type = components_list[0]
        for other in components_list[1:]:
            common_type = uniontype.find_common_implicitly_castable_type(
                other, schema)
            if common_type is None:
                raise _union_error(schema, components_list)
            uniontype = common_type
        return schema, uniontype, False

    schema, uniontype, created = s_objtypes.get_or_create_union_type(
        schema, components=components_list, opaque=opaque, module=module)
    return schema, uniontype, created
Beispiel #2
0
def get_union_type(schema,
                   types,
                   *,
                   opaque: bool = False,
                   module: typing.Optional[str] = None):
    """Return ``(schema, union_type)`` for the union of *types*.

    Inputs that are themselves unions are expanded into their members.  If
    only one distinct member remains and *opaque* is false, it is returned
    as-is.  Scalar members are folded pairwise into their common implicitly
    castable type.  Non-scalar members are handed to
    ``s_objtypes.get_or_create_union_type``.

    Raises:
        The error built by ``_union_error`` when scalar and non-scalar types
        are mixed, or when some pair of scalar members has no common
        implicitly castable type.
    """
    from edb.schema import objtypes as s_objtypes

    components = set()
    for t in types:
        union_of = t.get_union_of(schema)
        if union_of:
            components.update(union_of.objects(schema))
        else:
            components.add(t)

    if len(components) == 1 and not opaque:
        return schema, next(iter(components))

    components = list(components)

    seen_scalars = False
    seen_objtypes = False

    # A union may not mix scalar and non-scalar types.
    for component in components:
        if component.is_scalar():
            if seen_objtypes:
                raise _union_error(schema, components)
            seen_scalars = True
        else:
            if seen_scalars:
                raise _union_error(schema, components)
            seen_objtypes = True

    if seen_scalars:
        uniontype = components[0]
        for t1 in components[1:]:
            common_type = uniontype.find_common_implicitly_castable_type(
                t1, schema)
            # Check on every step: a None result must not be carried into
            # the next iteration, where calling a method on it would raise
            # AttributeError instead of the intended union error.
            if common_type is None:
                raise _union_error(schema, components)
            uniontype = common_type
    else:
        schema, uniontype = s_objtypes.get_or_create_union_type(
            schema, components=components, opaque=opaque, module=module)

    return schema, uniontype
Beispiel #3
0
def ensure_union_type(
    schema: s_schema.Schema,
    types: Iterable[s_types.Type],
    *,
    opaque: bool = False,
    module: Optional[str] = None,
    preserve_derived: bool = False,
) -> Tuple[s_schema.Schema, s_types.Type, bool]:
    """Return a type representing the union of *types*.

    Inputs that are themselves unions are expanded into their member types.
    When *preserve_derived* is true, members that are derived
    ``InheritingType`` instances skip minimization and are appended to the
    member list unchanged.  The remaining members are reduced with
    ``minimize_class_set_by_most_generic``, but only when every one of them
    is an ``InheritingType``.

    * A single remaining member is returned as-is unless *opaque* is set.
    * Non-``ObjectType`` (scalar) members are folded pairwise into their
      common implicitly castable type.
    * ``ObjectType`` members are passed to
      ``s_objtypes.get_or_create_union_type``.

    Returns:
        ``(schema, union_type, created)``.  ``created`` is ``False`` unless
        it comes from ``get_or_create_union_type``.

    Raises:
        The error built by ``_union_error`` when scalars and object types are
        mixed, or when two scalar members share no implicitly castable type.
    """
    from edb.schema import objtypes as s_objtypes
    from edb.schema import types as s_types

    # Expand nested unions into their members.
    flattened: Set[s_types.Type] = set()
    for typ in types:
        nested = typ.get_union_of(schema)
        if nested:
            flattened.update(nested.objects(schema))
        else:
            flattened.add(typ)

    # Separate derived types the caller wants kept verbatim from the ones
    # that may be minimized.
    kept_derived: Set[s_types.Type] = set()
    minimizable: Set[s_types.Type] = set()
    for typ in flattened:
        keep_verbatim = (
            preserve_derived
            and isinstance(typ, s_types.InheritingType)
            and typ.get_is_derived(schema)
        )
        (kept_derived if keep_verbatim else minimizable).add(typ)

    ordered: List[s_types.Type]
    if not all(isinstance(m, s_types.InheritingType) for m in minimizable):
        ordered = list(minimizable)
    else:
        ordered = list(minimize_class_set_by_most_generic(
            schema, cast(Set[s_types.InheritingType], minimizable)))
    ordered.extend(kept_derived)

    if not opaque and len(ordered) == 1:
        return schema, ordered[0], False

    # Scalars and object types cannot share a union; stop at the first
    # member whose kind differs from what was seen before.
    kinds_seen: Set[bool] = set()
    for member in ordered:
        kinds_seen.add(isinstance(member, s_objtypes.ObjectType))
        if len(kinds_seen) > 1:
            raise _union_error(schema, ordered)

    if False in kinds_seen:
        # Scalar members: reduce to one common castable type, threading the
        # (possibly updated) schema through each step.
        uniontype: s_types.Type = ordered[0]
        for other in ordered[1:]:
            schema, common_type = (
                uniontype.find_common_implicitly_castable_type(other, schema))
            if common_type is None:
                raise _union_error(schema, ordered)
            uniontype = common_type
        return schema, uniontype, False

    objtypes = cast(Sequence[s_objtypes.ObjectType], ordered)
    schema, uniontype, created = s_objtypes.get_or_create_union_type(
        schema,
        components=objtypes,
        opaque=opaque,
        module=module,
    )
    return schema, uniontype, created