def type_copies(dag: DAG[Node, str],
                head_deps: FrozenSet[str] = HeadDeps,
                mod_deps: FrozenSet[str] = ModDeps):
    """Type the coordinators of conjunctions whose conjuncts share ("copy") material.

    For each node of category 'conj', the descendants shared by all of its
    conjuncts are taken to be copies; the conjuncts' common type is abstracted
    over those missing copies into a polymorphic functor, which is stacked into
    a coordinator type and written onto the primary 'crd' target.  Any further
    coordinator targets receive the placeholder type.  Mutates ``dag.attribs``
    in place.

    Args:
        dag: the DAG being typed.
        head_deps: dependency labels regarded as heads.
        mod_deps: dependency labels regarded as modifiers.

    Raises:
        ExtractionError: on gap conjunctions, headless conjunctions,
            non-polymorphic conjunctions, multi-colored copies, and copies
            that are unaccounted for or untyped.
    """
    def daughterhood_conditions(daughter: Edge[Node, str]) -> bool:
        # True for edges that are neither head- nor modifier-labelled.
        return daughter.dep not in head_deps.union(mod_deps)

    def normalize_gap_copies(
            typecolors: List[Tuple[WordType, Set[str]]]) -> ArgSeq:
        # Collapse each copied node's incoming color set into one (type, color).
        def normalize_gap_copy(
                tc: Tuple[WordType, Set[str]]) -> Tuple[WordType, str]:
            if len(snd(tc)) == 1:
                # Plain copy: a single incoming dependency.
                return fst(tc), fst(list(snd(tc)))
            elif len(snd(tc)) == 2:
                # Gap copy: drop the head color, keep the other one, and
                # unwrap the type two argument levels deep.
                color = fst(list(filter(lambda c: c not in head_deps,
                                        snd(tc))))
                # NOTE(review): modifier-colored gaps yield a None color here
                # (the older variant below uses 'embedded') — confirm that
                # downstream consumers accept None.
                return fst(
                    tc
                ).argument.argument, color if color not in mod_deps else None
            else:
                raise ExtractionError('Multi-colored copy.', meta=dag.meta)

        return list(map(normalize_gap_copy, typecolors))

    def make_polymorphic_x(initial: WordType, missing: ArgSeq) -> WordType:
        # Abstract the missing (copied) arguments over the shared type.
        # missing = list(map(lambda pair: (fst(pair), snd(pair) if snd(pair) not in mod_deps else 'embedded'),
        #                    missing))
        return binarize_hots(missing, initial)

    def make_crd_type(poly_x: WordType, repeats: int) -> WordType:
        # Stack one 'cnj'-diamond argument per conjunct over the polymorphic type.
        ret = poly_x
        while repeats:
            ret = DiamondType(argument=poly_x, result=ret, diamond='cnj')
            repeats -= 1
        return ret

    # All nodes with category 'conj'.
    conjuncts = list(_cats_of_type(dag, 'conj'))

    gap_conjuncts = list(
        filter(lambda node: is_gap(dag, node, head_deps), conjuncts))
    if gap_conjuncts:
        raise ExtractionError('Gap conjunction.')

    # the edges coming out of each conjunct
    conj_outgoing_edges: List[Edges] = list(
        map(lambda c: dag.outgoing(c), conjuncts))

    # the list of coordinator edges coming out of each conjunct
    crds = list(
        map(lambda cg: list(filter(lambda edge: edge.dep == 'crd', cg)),
            conj_outgoing_edges))

    if any(list(map(lambda conj_group: len(conj_group) == 0, crds))):
        raise ExtractionError('Headless conjunction.', meta={'dag': dag.meta})

    # the list of non-excluded edges coming out of each conjunct
    conj_outgoing_edges = list(
        map(lambda cg: set(filter(daughterhood_conditions, cg)),
            conj_outgoing_edges))

    # the list of non-excluded nodes pointed by each conjunct
    conj_targets: List[Nodes] = list(
        map(lambda cg: set(map(lambda edge: edge.target, cg)),
            conj_outgoing_edges))

    # the list including only typed branches
    conj_targets = list(
        filter(
            lambda cg: all(
                list(
                    map(
                        lambda daughter: 'type' in dag.attribs[daughter].keys(
                        ), cg))), conj_targets))

    # Each conjunct group must share exactly one type (the polymorphic one).
    initial_typegroups: List[Set[WordType]] \
        = list(map(lambda conj_group: set(map(lambda daughter: dag.attribs[daughter]['type'], conj_group)),
                   conj_targets))

    if any(
            list(
                map(lambda conj_group: len(conj_group) != 1,
                    initial_typegroups))):
        raise ExtractionError('Non-polymorphic conjunction.',
                              meta={'dag': dag.meta})

    initial_types: WordTypes = list(
        map(lambda conj_group: fst(list(conj_group)), initial_typegroups))
    # Everything reachable from each daughter (the daughter included).
    downsets: List[List[Nodes]] \
        = list(map(lambda group_targets:
                   list(map(lambda daughter: dag.points_to(daughter).union({daughter}),
                            group_targets)),
                   conj_targets))
    # Material shared by all daughters of a conjunction: the copies.
    common_downsets: List[Nodes] = list(
        map(lambda downset: set.intersection(*downset), downsets))
    # Keep only the topmost shared nodes (no ancestor inside the shared set).
    minimal_downsets: List[Nodes] = list(
        map(
            lambda downset: set(
                filter(
                    lambda node: len(
                        dag.pointed_by(node).intersection(downset)) == 0,
                    downset)), common_downsets))

    accounted_copies = set.union(
        *minimal_downsets) if common_downsets else set()
    all_copies = set(filter(lambda node: is_copy(dag, node), dag.nodes))

    # Every copy in the graph must be explained by some conjunction.
    if accounted_copies != all_copies:
        raise ExtractionError('Unaccounted copies.', meta=dag.meta)
    if any(
            list(
                map(lambda acc: 'type' not in dag.attribs[acc].keys(),
                    accounted_copies))):
        raise ExtractionError('Untyped copies.', meta=dag.meta)

    # Pair each copy's type with the set of its incoming dependency labels.
    copy_colorsets = list(
        map(
            lambda downset: list(
                map(
                    lambda node:
                    (dag.attribs[node]['type'],
                     set(map(lambda edge: edge.dep, dag.incoming(node)))
                     ), downset)), minimal_downsets))

    copy_types_and_colors = list(map(normalize_gap_copies, copy_colorsets))

    polymorphic_xs = list(
        map(make_polymorphic_x, initial_types, copy_types_and_colors))
    crd_types = list(
        map(make_crd_type, polymorphic_xs, list(map(len, conj_targets))))
    # Coordinators beyond the first of each conjunction get the placeholder type.
    secondary_crds = list(
        map(lambda crd: crd.target,
            chain.from_iterable(crd[1::] for crd in crds)))
    primary_crds = list(map(lambda crd: crd.target, map(fst, crds)))
    copy_types = {
        crd: {
            **dag.attribs[crd],
            **{
                'type': crd_type
            }
        }
        for crd, crd_type in zip(primary_crds, crd_types)
    }
    dag.attribs.update(copy_types)
    secondary_types = {
        crd: {
            **dag.attribs[crd],
            **{
                'type': AtomicType('_')
            }
        }
        for crd in secondary_crds
    }
    dag.attribs.update(secondary_types)
# Lassy part-of-speech ('pt') tags mapped to the names of their atomic types.
_PtDict = {
    'adj': 'ADJ',
    'bw': 'BW',
    'let': 'LET',
    'lid': 'LID',
    'n': 'N',
    'spec': 'SPEC',
    'tsw': 'TSW',
    'tw': 'TW',
    'vg': 'VG',
    'vnw': 'VNW',
    'vz': 'VZ',
    'ww': 'WW'
}

# Atomic-type lexicons: category / POS / pt tag name -> AtomicType instance.
# NOTE(review): _CatDict and _PosDict are defined elsewhere in this module.
CatDict = {k: AtomicType(v) for k, v in _CatDict.items()}
PosDict = {k: AtomicType(v) for k, v in _PosDict.items()}
PtDict = {k: AtomicType(v) for k, v in _PtDict.items()}

# Head and modifier dependencies
HeadDeps = frozenset(['hd', 'rhd', 'whd', 'cmp', 'crd', 'det'])
ModDeps = frozenset(['mod', 'predm', 'app'])

# Obliqueness Hierarchy
ObliquenessOrder = (
    ('mod', 'app', 'predm'),  # modifiers
    ('body', 'rhd_body', 'whd_body'),  # clause bodies
    ('svp', ),  # phrasal verb part
    ('ld', 'me', 'vc'),  # verb complements
    ('predc', 'obj2', 'se', 'pc', 'hdf'),  # verb secondary arguments
    ('obj1', ),  # primary object
def type_heads_step(dag: DAG, head_deps: FrozenSet[str],
                    mod_deps: FrozenSet[str]) -> Optional[Dict[str, Dict]]:
    """One iteration of head typing: derive functor types for head nodes.

    A head edge qualifies when its label is a head dependency other than
    'crd', its source is already typed, its target is not yet typed, and the
    target is not a gap.  For each qualifying head whose sibling arguments are
    all typed, the head's type is the source's type rebinarized over those
    argument types and dependency colors.  Secondary ("double") heads —
    targets that are not the first head in their source's node order —
    receive the placeholder type instead.

    Returns:
        A dict mapping each newly typed node to its updated attribute dict.
        NOTE(review): annotated Optional, but this variant always returns a
        dict (possibly empty) — confirm callers treat both the same.
    """
    def make_hd_functor(result: WordType, argcs: Tuple[WordTypes,
                                                       strings]) -> WordType:
        # Heads with no arguments keep the result type as-is.
        return rebinarize(list(zip(*argcs)), result) if argcs else result

    # Pair each candidate head edge with the ordered head targets of its source.
    heads_nodes: List[Tuple[Edge, List[str]]] \
        = list(map(lambda edge: (edge,
                                 order_nodes(dag, list(set(map(lambda edge: edge.target,
                                                               filter(lambda edge: edge.dep in head_deps,
                                                                      dag.outgoing(edge.source))))))),
                   filter(lambda edge: edge.dep in head_deps.difference({'crd'})
                                       and 'type' in dag.attribs[edge.source].keys()
                                       and 'type' not in dag.attribs[edge.target].keys()
                                       and not is_gap(dag, edge.target, head_deps),
                          dag.edges)))

    # Secondary heads: targets that are not first in their source's order.
    double_heads = list(
        map(
            lambda edge: edge.target,
            map(
                fst,
                filter(lambda pair: fst(pair).target != fst(snd(pair)),
                       heads_nodes))))
    # NOTE(review): fst(pair) is an Edge while double_heads holds target
    # nodes, so this membership test may never exclude anything; the
    # placeholder entries appended below override duplicates anyway — verify.
    single_heads = list(
        map(fst, filter(lambda pair: fst(pair) not in double_heads,
                        heads_nodes)))

    # Keep only heads whose non-modifier sibling arguments are all typed.
    heading_edges: List[Tuple[Edge, List[Edge]]] \
        = list(filter(lambda pair: all(list(map(lambda out: 'type' in dag.attribs[out.target].keys(),
                                                snd(pair)))),
                      map(lambda pair: (fst(pair),
                                        list(filter(lambda edge: edge.dep not in mod_deps
                                                                 and edge != fst(pair)
                                                                 and edge.target not in double_heads,
                                                    snd(pair)))),
                          map(lambda edge: (edge, dag.outgoing(edge.source)), single_heads))))

    targets: strings = list(map(lambda pair: fst(pair).target, heading_edges))
    types: WordTypes = list(
        map(lambda pair: dag.attribs[fst(pair).source]['type'], heading_edges))

    def extract_argcs(edges: List[Edge]) -> Tuple[WordTypes, strings]:
        # Split sibling edges into parallel lists of argument types and colors.
        args = list(map(lambda edge: dag.attribs[edge.target]['type'], edges))
        cs = list(map(lambda edge: edge.dep, edges))
        return args, cs

    argcolors: List[Tuple[WordTypes, strings]] = list(
        map(lambda pair: extract_argcs(snd(pair)), heading_edges))

    # Functor types for single heads, placeholder types for double heads;
    # the dict comprehension below lets later (placeholder) entries win.
    head_types = [(node, make_hd_functor(res, argcs)) for node, res, argcs in zip(targets, types, argcolors)] + \
                 [(node, AtomicType('_')) for node in double_heads]

    return {
        **{
            node: {
                **dag.attribs[node],
                **{
                    'type': _type
                }
            }
            for node, _type in head_types
        }
    }
def type_copies(dag: DAG, head_deps: Set[str], mod_deps: Set[str]):
    """Type the coordinators of conjunctions whose conjuncts share ("copy") material.

    For each node of category 'conj', the descendants shared by all of its
    conjuncts are taken to be copies; the conjuncts' common type is abstracted
    over those missing copies (in obliqueness order) into a polymorphic
    functor, which is stacked into a coordinator type and written onto the
    first 'crd' target.  Remaining coordinator targets receive the placeholder
    type.  Mutates ``dag.attribs`` in place.

    NOTE(review): this module contains two definitions of ``type_copies``;
    if both are live, the later one shadows this one — confirm which is
    intended.

    Raises:
        ExtractionError: on gap conjunctions, headless conjunctions,
            non-polymorphic conjunctions, multi-colored copies, and copies
            that are unaccounted for or untyped.
    """
    def daughterhood_conditions(daughter: Edge) -> bool:
        # True for edges that are neither head- nor modifier-labelled.
        return daughter.dep not in head_deps.union(mod_deps)

    def normalize_gap_copies(
        typecolors: Sequence[Tuple[WordType, Set[str]]]
    ) -> Sequence[Tuple[WordType, str]]:
        # Collapse each copied node's incoming color set into one (type, color).
        def normalize_gap_copy(
                tc: Tuple[WordType, Set[str]]) -> Tuple[WordType, str]:
            if len(snd(tc)) == 1:
                # Plain copy: a single incoming dependency.
                return fst(tc), fst(list(snd(tc)))
            elif len(snd(tc)) == 2:
                # Gap copy: drop the head color, keep the other one, and
                # unwrap the type two argument levels deep; modifier colors
                # are renamed to 'embedded'.
                color = fst(list(filter(lambda c: c not in head_deps,
                                        snd(tc))))
                return fst(
                    tc
                ).argument.argument, color if color not in mod_deps else 'embedded'
            else:
                raise ExtractionError('Multi-colored copy.', meta=dag.meta)

        return list(map(normalize_gap_copy, typecolors))

    def make_polymorphic_x(
            initial: WordType, missing: Sequence[Tuple[WordType,
                                                       str]]) -> ColoredType:
        # Rename modifier colors to 'embedded', then abstract the missing
        # arguments over the shared type in obliqueness order.
        missing = list(
            map(
                lambda pair: (fst(pair), snd(pair)
                              if snd(pair) not in mod_deps else 'embedded'),
                missing))
        return binarize(_obliqueness_sort, list(map(fst, missing)),
                        list(map(snd, missing)), initial)

    def make_crd_type(poly_x: WordType, repeats: int) -> ColoredType:
        # Stack one 'cnj'-colored argument per conjunct over the polymorphic type.
        ret = poly_x
        while repeats:
            ret = ColoredType(argument=poly_x, result=ret, color='cnj')
            repeats -= 1
        return ret

    # All nodes with category 'conj'.
    conjuncts = list(_cats_of_type(dag, 'conj'))

    gap_conjuncts = list(
        filter(lambda node: is_gap(dag, node, head_deps), conjuncts))
    if gap_conjuncts:
        raise ExtractionError('Gap conjunction.')

    # Outgoing edges of each conjunction node; coordinators are the 'crd' ones.
    conj_groups = list(map(dag.outgoing, conjuncts))
    crds = list(
        map(
            lambda conj_group: list(
                filter(lambda edge: edge.dep == 'crd', conj_group)),
            conj_groups))
    if any(list(map(lambda conj_group: len(conj_group) == 0, crds))):
        raise ExtractionError('Headless conjunction.', meta={'dag': dag.meta})
    # Keep non-head, non-modifier daughters, then take their target nodes.
    conj_groups = list(
        map(
            lambda conj_group: list(filter(daughterhood_conditions, conj_group)
                                    ), conj_groups))
    conj_groups = list(
        map(lambda conj_group: list(map(lambda edge: edge.target, conj_group)),
            conj_groups))

    # Each conjunct group must share exactly one type (the polymorphic one).
    initial_types = list(
        map(
            lambda conj_group: set(
                map(lambda daughter: dag.attribs[daughter]['type'], conj_group)
            ), conj_groups))
    if any(list(map(lambda conj_group: len(conj_group) != 1, initial_types))):
        raise ExtractionError('Non-polymorphic conjunction.',
                              meta={'dag': dag.meta})

    initial_types = list(
        map(lambda conj_group: fst(list(conj_group)), initial_types))
    # todo: assert all missing args are the same
    # Everything reachable from each daughter (the daughter included).
    downsets = list(
        map(
            lambda conj_group: list(
                map(lambda daughter: dag.points_to(daughter).union({daughter}),
                    conj_group)), conj_groups))
    # Material shared by all daughters of a conjunction: the copies.
    common_downsets = list(
        map(lambda downset: set.intersection(*downset), downsets))
    # Keep only the topmost shared nodes (no ancestor inside the shared set).
    minimal_downsets = list(
        map(
            lambda downset: set(
                filter(
                    lambda node: len(
                        dag.pointed_by(node).intersection(downset)) == 0,
                    downset)), common_downsets))

    accounted_copies = set.union(
        *minimal_downsets) if common_downsets else set()
    all_copies = set(filter(lambda node: is_copy(dag, node), dag.nodes))
    # Every copy in the graph must be explained by some conjunction.
    if accounted_copies != all_copies:
        raise ExtractionError('Unaccounted copies.', meta=dag.meta)
    if any(
            list(
                map(lambda acc: 'type' not in dag.attribs[acc].keys(),
                    accounted_copies))):
        raise ExtractionError('Untyped copies.', meta=dag.meta)

    # Pair each copy's type with the set of its incoming dependency labels.
    copy_typecolors = list(
        map(
            lambda downset: list(
                map(
                    lambda node:
                    (dag.attribs[node]['type'],
                     set(map(lambda edge: edge.dep, dag.incoming(node)))
                     ), downset)), minimal_downsets))
    copy_typecolors = list(map(normalize_gap_copies, copy_typecolors))
    polymorphic_xs = list(
        map(make_polymorphic_x, initial_types, copy_typecolors))
    crd_types = list(
        map(make_crd_type, polymorphic_xs, list(map(len, conj_groups))))
    # Coordinators beyond the first of each conjunction get the placeholder.
    secondary_crds = list(chain.from_iterable(crd[1::] for crd in crds))
    secondary_crds = list(map(lambda crd: crd.target, secondary_crds))
    crds = list(map(fst, crds))
    crds = list(map(lambda crd: crd.target, crds))
    copy_types = {
        crd: {
            **dag.attribs[crd],
            **{
                'type': crd_type
            }
        }
        for crd, crd_type in zip(crds, crd_types)
    }
    dag.attribs.update(copy_types)
    secondary_types = {
        crd: {
            **dag.attribs[crd],
            **{
                'type': AtomicType('_')
            }
        }
        for crd in secondary_crds
    }
    dag.attribs.update(secondary_types)
def type_heads_step(dag: DAG, head_deps: Set[str],
                    mod_deps: Set[str]) -> Optional[Dict[Node, Dict]]:
    """One iteration of head typing: derive functor types for head nodes.

    A head edge qualifies when its label is a head dependency other than
    'crd', its source is already typed, its target is not yet typed, and the
    target is not a gap.  For each qualifying head whose sibling arguments are
    all typed, the head's type is the source's type rebinarized over those
    argument types and dependency colors.  Secondary ("double") heads —
    targets that are not the first head in their source's node order —
    receive the placeholder type instead.

    Returns:
        None when no heads remain to type; otherwise a dict mapping each
        newly typed node to its updated attribute dict.
    """
    def make_functor(res: WordType, argcolors: Tuple[WordTypes,
                                                     strings]) -> ColoredType:
        # Rebinarize the source's type over the argument types and colors.
        return rebinarize(_obliqueness_sort, fst(argcolors), snd(argcolors),
                          res, mod_deps)

    # Candidate head edges: head label (not 'crd'), typed source, untyped
    # target, target not a gap.
    heading_edges = list(
        filter(
            lambda edge: edge.dep in head_deps.difference(
                {'crd'}) and 'type' in dag.attribs[edge.source].keys() and
            'type' not in dag.attribs[edge.target].keys() and not is_gap(
                dag, edge.target, head_deps), dag.edges))
    heading_edges = list(
        map(lambda edge: (edge, dag.outgoing(edge.source)), heading_edges))
    # Pair each edge with the ordered head targets of its source.
    heading_edges = list(
        map(
            lambda pair:
            (fst(pair),
             set(
                 map(lambda edge: edge.target,
                     filter(lambda edge: edge.dep in head_deps, snd(pair))))),
            heading_edges))
    heading_edges = list(
        map(lambda pair: (fst(pair), order_nodes(dag, snd(pair))),
            heading_edges))
    # Secondary heads: targets that are not first in their source's order.
    double_heads = list(
        map(
            fst,
            filter(lambda pair: fst(pair).target != fst(snd(pair)),
                   heading_edges)))
    double_heads = list(map(lambda edge: edge.target, double_heads))

    # NOTE(review): fst(pair) is an Edge while double_heads now holds target
    # nodes, so this membership test may never exclude anything; the
    # placeholder entries merged in below override duplicates anyway — verify.
    single_heads = list(
        map(fst,
            filter(lambda pair: fst(pair) not in double_heads, heading_edges)))
    heading_edges = list(
        map(lambda edge: (edge, dag.outgoing(edge.source)), single_heads))

    # Sibling arguments: non-modifier edges other than the head itself.
    heading_edges = list(
        map(
            lambda pair:
            (fst(pair),
             list(
                 filter(
                     lambda edge: edge.dep not in mod_deps and
                     edge != fst(pair) and edge.target not in double_heads,
                     snd(pair)))), heading_edges))

    # Only heads whose sibling arguments are all typed can be typed now.
    heading_edges = list(
        filter(
            lambda pair: all(
                list(
                    map(lambda out: 'type' in dag.attribs[out.target].keys(),
                        snd(pair)))), heading_edges))
    if not heading_edges and not double_heads:
        return
    # (head target, source type, zipped argument types/colors) per head.
    heading_argcs = list(
        map(
            lambda pair:
            (fst(pair).target, dag.attribs[fst(pair).source]['type'],
             list(
                 zip(*map(
                     lambda out:
                     (dag.attribs[out.target]['type'], out.dep), snd(pair))))),
            heading_edges))
    # Functor types for single heads; placeholder for double heads (the
    # second spread wins on duplicate keys).
    head_types = {
        **{
            node: {
                **dag.attribs[node],
                **{
                    'type': make_functor(res, argcolors) if argcolors else res
                }
            }
            for (node, res, argcolors) in heading_argcs
        },
        **{
            node: {
                **dag.attribs[node],
                **{
                    'type': AtomicType('_')
                }
            }
            for node in double_heads
        }
    }
    return head_types
}
# Lassy part-of-speech ('pt') tags mapped to the names of their atomic types.
_pt_dict = {
    'adj': 'ADJ',
    'bw': 'BW',
    'let': 'LET',
    'lid': 'LID',
    'n': 'N',
    'spec': 'SPEC',
    'tsw': 'TSW',
    'tw': 'TW',
    'vg': 'VG',
    'vnw': 'VNW',
    'vz': 'VZ',
    'ww': 'WW'
}
# Convert the raw tag dictionaries into AtomicType lexicons in place.
# NOTE(review): _cat_dict and _pos_dict are defined elsewhere in this module.
_cat_dict = {k: AtomicType(v) for k, v in _cat_dict.items()}
_pos_dict = {k: AtomicType(v) for k, v in _pos_dict.items()}
_pt_dict = {k: AtomicType(v) for k, v in _pt_dict.items()}
# Head and modifier dependencies
_head_deps = {'hd', 'rhd', 'whd', 'cmp', 'crd'}
_mod_deps = {'mod', 'predm', 'app'}
# Obliqueness Hierarchy
_obliqueness_order = (
    ('mod', 'app', 'predm'),  # modifiers
    ('body', 'rhd_body', 'whd_body'),  # clause bodies
    ('svp', ),  # phrasal verb part
    ('ld', 'me', 'vc'),  # verb complements
    ('predc', 'obj2', 'se', 'pc', 'hdf'),  # verb secondary arguments
    ('obj1', ),  # primary object
    ('pobj', ),  # preliminary object
    ('su', ),  # primary subject
from functools import reduce
from itertools import chain

from LassyExtraction.extraction import order_nodes, is_gap, is_copy, _head_deps, _mod_deps
from LassyExtraction.graphutils import *
from LassyExtraction.milltypes import polarize_and_index_many, polarize_and_index, WordType, \
    PolarizedIndexedType, ColoredType, AtomicType, depolarize
from LassyExtraction.transformations import _cats_of_type

from LassyExtraction.viz import ToGraphViz

# A proof net is a set of axiom links: pairs of (positive, negative) atom indices.
ProofNet = Set[Tuple[int, int]]

# Placeholder type assigned to nodes that carry no meaningful type.
placeholder = AtomicType('_')


class ProofError(AssertionError):
    """Signals an inconsistency encountered while building a proof net."""

    def __init__(self, message: str):
        # Carry the message through the standard exception machinery.
        super().__init__(message)


def match(proofnet: ProofNet, positive: WordType,
          negative: WordType) -> ProofNet:
    if positive != negative:
        raise ProofError('Formulas are not equal.\t{}\t{}'.format(
            positive, negative))
    if any(map(lambda x: not is_indexed(x), [positive, negative])):
        raise ProofError(
            'Input formulas are not fully indexed.\t{}\t{}'.format(
                positive, negative))
    if isinstance(positive, PolarizedIndexedType):
"""

from dataclasses import dataclass

from itertools import chain
from functools import reduce
from operator import add

from LassyExtraction.aethel import ProofNet, AxiomLinks
from LassyExtraction.milltypes import (WordType, AtomicType, binarize_polish,
                                       FunctorType, get_polarities_and_indices,
                                       ModalType)
from LassyExtraction.extraction import CatDict, PtDict
from typing import Optional

# Atomic placeholder type for multi-word units.
MWU = AtomicType('_MWU')

# Atom names collated into another atom and therefore excluded from the atom set.
_atom_collations = {'SPEC': 'NP'}


def make_atom_set() -> list[AtomicType]:
    """Collect the atomic types in use, sorted by their string rendering.

    Unions the POS-tag and category atoms, removes the collated ones, and
    keeps the multi-word-unit placeholder.
    """
    atoms = set(PtDict.values()) | set(CatDict.values())
    collated = {AtomicType(name) for name in _atom_collations}
    return sorted((atoms - collated) | {MWU}, key=str)


# Precomputed, sorted list of all atomic types used by this module.
_atom_set = make_atom_set()