class DirectedGraphData(GraphData):
    """Directed variant of GraphData.

    Edges are stored as ordered ``(source, target)`` tuples. Besides the
    undirected adjacency in ``__neighbor_sets__`` (two nodes are linked there
    whenever an edge exists in either direction), each node has separate
    in-neighbor and out-neighbor sets. Trait dicts are plain ``dict``s
    (GraphData uses UndirectedDict instead).
    """

    def __init__(self):
        # Same storage as GraphData.__init__, initialised explicitly here.
        self.__nodes__ = SampleSet()
        self.__edges__ = SampleSet()
        self.__traits__ = {}
        self.__neighbor_sets__ = {}

        self.__savepoint_set__ = False
        self.__savepoint_name__ = None
        self.__undo_log__ = []

        # Direction-aware adjacency: node -> set of predecessors / successors.
        self.__in_neighbor_sets__ = {}
        self.__out_neighbor_sets__ = {}
        # Edge keys are ordered pairs, so traits use a plain dict.
        self.__trait_dict_type__ = dict

    def add_node(self, node):
        """Add ``node`` with no edges; a no-op if it is already present."""
        if node in self.__nodes__:
            return

        if self.__savepoint_set__:
            self.__undo_log__.append((GraphData.__ADD_NODE__, node))

        self.__nodes__.add(node)
        self.__neighbor_sets__[node] = set()
        self.__in_neighbor_sets__[node] = set()
        self.__out_neighbor_sets__[node] = set()

    def delete_node(self, node):
        """Remove ``node`` and everything attached to it.

        This removes all edges into or out of the node, plus every trait
        value keyed by the node or by one of those edges. It is a no-op if
        ``node`` is absent.
        """
        if node not in self.__nodes__:
            return

        out_neighbors = list(self.__out_neighbor_sets__[node])
        in_neighbors = list(self.__in_neighbor_sets__[node])
        # A self-loop appears in both lists. That is harmless in the undo log
        # (add_edge is idempotent), but the edge must be removed only once.
        incident_edges = ([(node, n) for n in out_neighbors] +
                          [(n, node) for n in in_neighbors])

        if self.__savepoint_set__:
            self.__undo_log__.append(
                (GraphData.__DEL_NODE__, node, incident_edges))

        # Unlink the node from the *other* endpoints. Skipping ``node`` itself
        # avoids mutating a set while iterating over it when a self-loop
        # exists. The node's own adjacency sets are dropped wholesale below.
        for neighbor in self.__neighbor_sets__[node]:
            if neighbor != node:
                self.__neighbor_sets__[neighbor].remove(node)
        for neighbor in in_neighbors:
            if neighbor != node:
                self.__out_neighbor_sets__[neighbor].remove(node)
        for neighbor in out_neighbors:
            if neighbor != node:
                self.__in_neighbor_sets__[neighbor].remove(node)
        del self.__neighbor_sets__[node]
        del self.__in_neighbor_sets__[node]
        del self.__out_neighbor_sets__[node]

        removed_edges = list(dict.fromkeys(incident_edges))  # dedupe, keep order
        for edge in removed_edges:
            self.__edges__.remove(edge)
        self.__nodes__.remove(node)

        # Drop trait values keyed by the node and by its removed edges, in the
        # same way delete_edge drops edge traits.
        for trait_dict in self.__traits__.values():
            if node in trait_dict:
                del trait_dict[node]
            for edge in removed_edges:
                if edge in trait_dict:
                    del trait_dict[edge]

    def add_edge(self, source, target):
        """Add the directed edge ``source -> target``; a no-op if present.

        Raises:
            KeyError: if either endpoint is not a node. The check runs before
                anything is logged or mutated, so the graph stays consistent.
        """
        edge = (source, target)
        if edge in self.__edges__:
            return
        for endpoint in edge:
            if endpoint not in self.__nodes__:
                raise KeyError(endpoint)

        if self.__savepoint_set__:
            self.__undo_log__.append((GraphData.__ADD_EDGE__, edge))

        self.__edges__.add(edge)
        self.__neighbor_sets__[source].add(target)
        self.__neighbor_sets__[target].add(source)
        self.__out_neighbor_sets__[source].add(target)
        self.__in_neighbor_sets__[target].add(source)

    def delete_edge(self, source, target):
        """Remove the directed edge ``source -> target`` and its trait values.

        A no-op if the edge is absent.
        """
        edge = (source, target)
        if edge not in self.__edges__:
            return

        if self.__savepoint_set__:
            self.__undo_log__.append((GraphData.__DEL_EDGE__, edge))

        # The undirected link survives while the reverse edge still exists.
        # A self-loop is its own "reverse", so it is handled separately;
        # otherwise the node would never be unlinked from itself.
        if source == target:
            self.__neighbor_sets__[source].remove(target)
        elif (target, source) not in self.__edges__:
            self.__neighbor_sets__[source].remove(target)
            self.__neighbor_sets__[target].remove(source)
        self.__in_neighbor_sets__[target].remove(source)
        self.__out_neighbor_sets__[source].remove(target)
        self.__edges__.remove(edge)
        for trait_dict in self.__traits__.values():
            if edge in trait_dict:
                del trait_dict[edge]

    def is_directed(self):
        return True

    def has_edge(self, source, target):
        """Return True if the directed edge ``source -> target`` exists."""
        return (source, target) in self.__edges__

    # Caution: returns the internal successor set itself, not a copy.
    # Mutating it corrupts the graph.
    def out_neighbors(self, node):
        return self.__out_neighbor_sets__[node]

    # Caution: returns the internal predecessor set itself, not a copy.
    # Mutating it corrupts the graph.
    def in_neighbors(self, node):
        return self.__in_neighbor_sets__[node]
class GraphData:
    """Mutable undirected graph with named traits and a single-level savepoint.

    Edges are stored as ``(min(u, v), max(u, v))`` tuples, so node labels
    must be mutually orderable. Traits are named dicts, created with
    ``add_trait``, that map nodes or edges to values. While a savepoint is
    set, structural changes are appended to an undo log.
    ``restore_to_savepoint`` replays that log backwards and rolls every
    trait dict back to its own savepoint.
    """

    def __init__(self):
        self.__nodes__ = SampleSet()
        self.__edges__ = SampleSet()
        self.__traits__ = {}
        self.__savepoint_set__ = False
        self.__savepoint_name__ = None
        self.__undo_log__ = []

        self.__neighbor_sets__ = {}
        # UndirectedDict is defined below in this file.
        self.__trait_dict_type__ = UndirectedDict

    # Undo-log entry tags. Every entry is a tuple that starts with one of
    # these tags:
    #   (__ADD_TRAIT__, name)
    #   (__ADD_NODE__, node)
    #   (__ADD_EDGE__, edge)
    #   (__DEL_NODE__, node, [edges incident to node when it was deleted])
    #   (__DEL_EDGE__, edge)
    __ADD_TRAIT__ = 0
    __ADD_NODE__ = 1
    __ADD_EDGE__ = 2
    __DEL_NODE__ = 3
    __DEL_EDGE__ = 4

    @classmethod
    def FromNetworkX(cls, nx_graph):
        """Build an instance of ``cls`` from a NetworkX graph.

        Each node or edge attribute becomes a trait with the same name.
        """
        # nodes(data=True) / edges(data=True) exist in NetworkX 1.x and 2.x+.
        # They replace a lexicographic version-string check and the 1.x
        # ``edge[(a, b)]`` lookup, which indexed the adjacency dict by a tuple.
        graph = cls()
        for node, attributes in nx_graph.nodes(data=True):
            graph.add_node(node)
            for attribute, value in attributes.items():
                graph.add_trait(attribute)  # no-op if the trait exists
                graph[attribute][node] = value

        for a, b, attributes in nx_graph.edges(data=True):
            graph.add_edge(a, b)
            for attribute, value in attributes.items():
                graph.add_trait(attribute)
                graph[attribute][(a, b)] = value

        return graph

    def add_trait(self, name):
        """Create an empty trait dict called ``name`` if it does not exist."""
        if name not in self.__traits__:
            self.__traits__[name] = \
                SavepointDictWrapper(self.__trait_dict_type__())
            if self.__savepoint_set__:
                self.__undo_log__.append((GraphData.__ADD_TRAIT__, name))

    def add_node(self, node):
        """Add ``node`` with no edges; a no-op if it is already present."""
        if node in self.__nodes__:
            return

        if self.__savepoint_set__:
            self.__undo_log__.append((GraphData.__ADD_NODE__, node))

        self.__nodes__.add(node)
        self.__neighbor_sets__[node] = set()

    def delete_node(self, node):
        """Remove ``node`` and everything attached to it.

        This removes the node's incident edges, plus every trait value keyed
        by the node or by one of those edges. It is a no-op if ``node`` is
        absent.
        """
        if node not in self.__nodes__:
            return

        neighbors = list(self.__neighbor_sets__[node])
        if self.__savepoint_set__:
            self.__undo_log__.append(
                (GraphData.__DEL_NODE__, node, [(node, n) for n in neighbors]))

        incident_edges = []
        for neighbor in neighbors:
            # For a self-loop, neighbor == node. Its own set is deleted below,
            # and removing from it here is unnecessary.
            if neighbor != node:
                self.__neighbor_sets__[neighbor].remove(node)
            edge = (min(node, neighbor), max(node, neighbor))
            self.__edges__.remove(edge)
            incident_edges.append(edge)
        del self.__neighbor_sets__[node]
        self.__nodes__.remove(node)

        # Drop trait values keyed by the node and by its removed edges, in the
        # same way delete_edge drops edge traits.
        for trait_dict in self.__traits__.values():
            if node in trait_dict:
                del trait_dict[node]
            for edge in incident_edges:
                if edge in trait_dict:
                    del trait_dict[edge]

    def add_edge(self, source, target):
        """Add the undirected edge ``{source, target}``; a no-op if present.

        Raises:
            KeyError: if either endpoint is not a node. The check runs before
                anything is logged or mutated, so the graph stays consistent.
        """
        edge = (min(source, target), max(source, target))
        if edge in self.__edges__:
            return
        for endpoint in (source, target):
            if endpoint not in self.__nodes__:
                raise KeyError(endpoint)

        if self.__savepoint_set__:
            self.__undo_log__.append((GraphData.__ADD_EDGE__, edge))

        self.__edges__.add(edge)
        self.__neighbor_sets__[source].add(target)
        self.__neighbor_sets__[target].add(source)

    def delete_edge(self, source, target):
        """Remove the edge ``{source, target}`` and its trait values.

        A no-op if the edge is absent.
        """
        edge = (min(source, target), max(source, target))
        # Check existence *before* touching the neighbor sets. Otherwise a
        # missing edge raises KeyError and this guard is never reached.
        if edge not in self.__edges__:
            return

        if self.__savepoint_set__:
            self.__undo_log__.append((GraphData.__DEL_EDGE__, edge))

        self.__neighbor_sets__[source].remove(target)
        if source != target:  # a self-loop has a single neighbor entry
            self.__neighbor_sets__[target].remove(source)
        self.__edges__.remove(edge)
        for trait_dict in self.__traits__.values():
            if edge in trait_dict:
                del trait_dict[edge]

    # Caution: returns the internal node set itself, not a copy.
    def nodes(self):
        return self.__nodes__

    # Caution: returns the internal edge set itself, not a copy.
    def edges(self):
        return self.__edges__

    def num_nodes(self):
        return len(self.__nodes__)

    def num_edges(self):
        return len(self.__edges__)

    def is_directed(self):
        return False

    def has_node(self, node):
        return node in self.__nodes__

    def has_edge(self, source, target):
        """Return True if the undirected edge ``{source, target}`` exists."""
        return (min(source, target), max(source, target)) in self.__edges__

    def random_node(self):
        return self.__nodes__.randomly_sample()

    def random_edge(self):
        return self.__edges__.randomly_sample()

    # Caution: returns the internal neighbor set itself, not a copy.
    # Mutating it corrupts the graph.
    def neighbors(self, node):
        return self.__neighbor_sets__[node]

    def __getitem__(self, key):
        """Return the trait dict named ``key``.

        Raises:
            ValueError: if the trait was never added.
        """
        if key not in self.__traits__:
            raise ValueError("Error! Trait %s not found in graph. " % key + \
                "Use add_trait(%s) first." % key)
        return self.__traits__[key]

    def __setitem__(self, key, value):
        raise ValueError("Error! Traits must be set via add_trait().")

    def copy(self):
        """Return a new graph with the same nodes, edges and trait values.

        The copy is built through a fresh constructor, so it has no
        savepoint and an empty undo log.
        """
        c = DirectedGraphData() if self.is_directed() else GraphData()
        for node in self.__nodes__:
            c.add_node(node)
        for (a, b) in self.__edges__:
            c.add_edge(a, b)
        for trait_name, trait_dict in self.__traits__.items():
            c.add_trait(trait_name)
            for element, trait_value in trait_dict.items():
                c[trait_name][element] = trait_value
        return c

    # `name` is for debugging purposes only.
    def set_savepoint(self, name=None):
        """Start recording changes so they can be undone.

        Raises:
            ValueError: if a savepoint is already set.
        """
        if self.__savepoint_set__:
            if self.__savepoint_name__ is not None:
                old_name_str = " (name of previous savepoint: %s)" % \
                    self.__savepoint_name__
            else:
                old_name_str = ""
            raise ValueError("Error! Setting a graph_data savepoint when " + \
                "one is already set!" + old_name_str)
        self.__savepoint_name__ = name
        self.__savepoint_set__ = True
        self.__undo_log__ = []
        for trait_dict in self.__traits__.values():
            trait_dict.set_savepoint()

    def restore_to_savepoint(self):
        """Roll the graph back to its state at the last ``set_savepoint``.

        Afterwards the savepoint remains set with an empty undo log, so the
        caller can keep editing and restore again.
        """
        # NOTE(review): nothing guards against calling this without an active
        # savepoint. In that case the savepoint ends up set.
        # Stop logging while the undo operations below are replayed.
        self.__savepoint_set__ = False

        # Step 1: undo structural changes, newest first. The trait dicts still
        # hold their savepoints at this point, so any trait entries that
        # delete_node/delete_edge remove during the replay are rolled back in
        # step 2. If the traits were restored first, the replay could delete
        # values that had just been restored.
        for undo_data in reversed(self.__undo_log__):
            undo_type = undo_data[0]
            if undo_type == GraphData.__ADD_TRAIT__:
                del self.__traits__[undo_data[1]]
            elif undo_type == GraphData.__ADD_NODE__:
                self.delete_node(undo_data[1])
            elif undo_type == GraphData.__ADD_EDGE__:
                self.delete_edge(undo_data[1][0], undo_data[1][1])
            elif undo_type == GraphData.__DEL_NODE__:
                self.add_node(undo_data[1])
                for source, target in undo_data[2]:
                    self.add_edge(source, target)
            else:  # GraphData.__DEL_EDGE__
                self.add_edge(undo_data[1][0], undo_data[1][1])

        # Step 2: traits added after the savepoint were deleted in step 1.
        # Roll the remaining traits back, then re-arm their savepoints.
        for trait_dict in self.__traits__.values():
            trait_dict.restore_to_savepoint()
            trait_dict.clear_savepoint()
            trait_dict.set_savepoint()

        # The entries just replayed must not be replayed by a later restore.
        self.__undo_log__ = []
        self.__savepoint_set__ = True

    def clear_savepoint(self):
        """Stop recording changes and discard the undo log."""
        self.__savepoint_set__ = False
        self.__undo_log__ = []
        self.__savepoint_name__ = None
        for trait_dict in self.__traits__.values():
            trait_dict.clear_savepoint()