from functools import reduce
from itertools import chain

from LassyExtraction.extraction import order_nodes, is_gap, is_copy, _head_deps, _mod_deps
from LassyExtraction.graphutils import *
from LassyExtraction.milltypes import polarize_and_index_many, polarize_and_index, WordType, \
    PolarizedIndexedType, ColoredType, AtomicType, depolarize
from LassyExtraction.transformations import _cats_of_type

from LassyExtraction.viz import ToGraphViz

# Type alias: a proof net is represented as a set of (int, int) pairs.
# NOTE(review): presumably each pair links two atom indices (an axiom
# link) — confirm against callers.
ProofNet = Set[Tuple[int, int]]

# Atomic type named '_'; presumably used as a filler/placeholder atom — confirm usage.
placeholder = AtomicType('_')


class ProofError(AssertionError):
    """Error signalling that a proof-net operation received invalid input.

    Raised, for example, by ``match`` when the two formulas differ or are
    not fully indexed. The message is passed through unchanged to
    ``AssertionError``, so ``str(err)`` and ``err.args`` expose it.
    """

    def __init__(self, message: str):
        super(ProofError, self).__init__(message)


def match(proofnet: ProofNet, positive: WordType,
          negative: WordType) -> ProofNet:
    if positive != negative:
        raise ProofError('Formulas are not equal.\t{}\t{}'.format(
            positive, negative))
    if any(map(lambda x: not is_indexed(x), [positive, negative])):
        raise ProofError(
            'Input formulas are not fully indexed.\t{}\t{}'.format(
                positive, negative))
    if isinstance(positive, PolarizedIndexedType):
}
_pt_dict = {
    'adj': 'ADJ',
    'bw': 'BW',
    'let': 'LET',
    'lid': 'LID',
    'n': 'N',
    'spec': 'SPEC',
    'tsw': 'TSW',
    'tw': 'TW',
    'vg': 'VG',
    'vnw': 'VNW',
    'vz': 'VZ',
    'ww': 'WW'
}
# Rebind each tag table so its string values become AtomicType instances
# (keys and key order are kept).
_cat_dict, _pos_dict, _pt_dict = [
    {tag: AtomicType(name) for tag, name in table.items()}
    for table in (_cat_dict, _pos_dict, _pt_dict)
]
# Head and modifier dependencies
_head_deps = {'hd', 'rhd', 'whd', 'cmp', 'crd'}
_mod_deps = {'mod', 'predm', 'app'}
# Obliqueness Hierarchy
_obliqueness_order = (
    ('mod', 'app', 'predm'),  # modifiers
    ('body', 'rhd_body', 'whd_body'),  # clause bodies
    ('svp', ),  # phrasal verb part
    ('ld', 'me', 'vc'),  # verb complements
    ('predc', 'obj2', 'se', 'pc', 'hdf'),  # verb secondary arguments
    ('obj1', ),  # primary object
    ('pobj', ),  # preliminary object
    ('su', ),  # primary subject
"""

from dataclasses import dataclass

from itertools import chain
from functools import reduce
from operator import add

from LassyExtraction.aethel import ProofNet, AxiomLinks
from LassyExtraction.milltypes import (WordType, AtomicType, binarize_polish,
                                       FunctorType, get_polarities_and_indices,
                                       ModalType)
from LassyExtraction.extraction import CatDict, PtDict
from typing import Optional

# Atomic type named '_MWU'; make_atom_set always adds it to the atom inventory.
# NOTE(review): presumably marks multi-word units — confirm.
MWU = AtomicType('_MWU')

_atom_collations = {'SPEC': 'NP'}


def make_atom_set() -> list[AtomicType]:
    """Return the sorted inventory of atomic types.

    The inventory is every atom appearing as a value of ``PtDict`` or
    ``CatDict``, minus the atoms named by the keys of ``_atom_collations``,
    plus ``MWU``. Atoms are ordered by their string form.
    """
    pt_atoms = set(PtDict.values())
    cat_atoms = set(CatDict.values())
    collated = {AtomicType(name) for name in _atom_collations}
    # `-` binds tighter than `|`: (pt | cat) minus collated, then add MWU.
    inventory = (pt_atoms | cat_atoms) - collated | {MWU}
    return sorted(inventory, key=str)


# Sorted atom inventory, computed once at import time.
_atom_set = make_atom_set()