예제 #1
0
def construct_graph_from_dict():
    """Construct a small example hypergraph from a dictionary of scenes.

    The local ``scenes`` dictionary maps each edge id (an integer scene
    number) to the tuple of node ids (two-letter character codes) appearing in
    that scene. The dictionary is passed directly to ``hnx.Hypergraph``.

    Notes:
        A hypergraph is a pair of entity sets (Nodes, Edges). Every element
        of Edges is a (hyper)edge that is either empty or a collection of
        elements of Nodes, and every node is a member of some edge. Here the
        nodes are exactly the codes that occur in ``scenes``.

    Returns:
        hnx.Hypergraph: the hypergraph built from ``scenes``.
    """
    scenes = {
        0: ('FN', 'TH'),
        1: ('TH', 'JV'),
        2: ('BM', 'FN', 'JA'),
        3: ('JV', 'JU', 'CH', 'BM'),
        4: ('JU', 'CH', 'BR', 'CN', 'CC', 'JV', 'BM'),
        5: ('TH', 'GP'),
        6: ('GP', 'MP'),
        7: ('MA', 'GP')
    }
    return hnx.Hypergraph(scenes)
예제 #2
0
def create_hypergraph(literals: LiteralInfo):
    """Utility function to create drawable hypergraphs.

    Nodes are logic variables. For every continuous variable, the logic
    variables it is linked to (per ``conversion_tables``) form a candidate
    edge. Only candidates with more than one logic variable are recorded, in a
    ``HyperEdgeContainer`` (presumably accumulating a count per distinct
    edge; confirm against its definition). Each recorded edge is then given a
    consecutive integer id.

    Args:
        literals: literal information passed to ``conversion_tables``.

    Returns:
        hnx.Hypergraph mapping integer edge ids to edges of logic variables.
    """
    import hypernetx as hnx

    logic2cont, cont2logic = conversion_tables(literals)
    logic_variables = logic2cont.keys()

    edge_capacity = HyperEdgeContainer(0)

    # Encode formulas: a continuous variable coupling several logic variables
    # contributes one unit of capacity to the edge over those variables.
    for cvar in cont2logic.keys():
        lvars = cont2logic[cvar] & logic_variables
        if len(lvars) > 1:
            edge_capacity[lvars] += 1

    # Iterating the container yields (edge, capacity) pairs; only the edge is
    # needed for drawing.
    edges = {i: edge for i, (edge, _capacity) in enumerate(edge_capacity)}

    return hnx.Hypergraph(edges)
예제 #3
0
def read_hypergraph(filepath: str):
    """
    Read one or more hypergraphs from a txt file with a part-json like format.

    Hypergraphs are separated by a newline, one or more whitespace characters
    (which may include further newlines), and another newline. Each chunk has
    the form ``<name>=<edges>``. The text after the first '=' is parsed by
    ``process_graph_edges``. Whitespace-only chunks (e.g. a trailing
    separator at the end of the file) are ignored.

    :param filepath: path of the hypergraph file
    :type filepath: str
    :return: one hypergraph per chunk, each named after its chunk's name part
    :rtype: list
    :raises ValueError: if a non-blank chunk contains no '=' separator
    """
    with open(filepath, 'r') as graph_file:
        file_contents = graph_file.read()

    # Separate the hypergraphs based on this regex:
    # newline followed by one or more whitespace followed by newline
    chunks = re.split(r'\n\s+\n', file_contents)

    hgraphs = []
    for chunk in tqdm(chunks):
        # A whitespace-only chunk describes no hypergraph; unpacking it below
        # would raise.
        if not chunk.strip():
            continue
        # The name and graph are separated by the first '='; any later '='
        # belongs to the edge description.
        graph_name, graph_dict = chunk.split('=', 1)
        graph_dict = process_graph_edges(graph_dict)
        hgraphs.append(hnx.Hypergraph(graph_dict, name=graph_name))

    return hgraphs
예제 #4
0
def load_graphs(config):
    """Load the cached hypergraph, line graph and barcode selected by *config*.

    ``config['hgraph_type']`` selects the hypergraph suffix: "" for
    "collapsed_version", otherwise "_original". ``config['variant']`` selects
    the line-graph prefix: "" for "line_graph", otherwise "_dual".
    ``config['weight_type']`` selects the barcode suffix: "_is" for
    "intersection_size", otherwise "_ji".

    Returns:
        (hgraph, lgraph, barcode): node-link data of the bipartite graph of
        the loaded hypergraph (after ``assign_hgraph_singletons`` with the
        line graph's singletons), plus the line graph and barcode JSON.
    """
    f_hgraph = "" if config['hgraph_type'] == "collapsed_version" else "_original"
    f_variant = "" if config['variant'] == "line_graph" else "_dual"
    f_weight = "_is" if config['weight_type'] == "intersection_size" else "_ji"

    def read_upload(relative_name):
        # All cached artefacts live under APP_STATIC.
        with open(path.join(APP_STATIC, relative_name)) as handle:
            return json.load(handle)

    raw_hgraph = read_upload("uploads/current_hypergraph" + f_hgraph + ".json")
    lgraph = read_upload("uploads/current" + f_variant + "_linegraph" +
                         f_hgraph + ".json")
    barcode = read_upload("uploads/current" + f_variant + "_barcode" +
                          f_weight + f_hgraph + ".json")

    bipartite = hnx.Hypergraph(raw_hgraph).bipartite()
    hgraph = nx.readwrite.json_graph.node_link_data(bipartite)
    assign_hgraph_singletons(hgraph, lgraph['singletons'])
    return hgraph, lgraph, barcode
예제 #5
0
 def __init__(self):
     """Build the nine-edge "BigFish" example hypergraph.

     Every edge is named by concatenating its single-letter node names, so
     each edge's node set is the set of characters of its name.
     """
     edge_names = ("AB", "BC", "ACD", "BEH", "CF", "AG", "ADI", "ACI", "CDI")
     self.edgedict = {edge_name: set(edge_name) for edge_name in edge_names}
     self.hypergraph = hnx.Hypergraph(self.edgedict, name="BigFish")
예제 #6
0
def compute_simplified_hgraph():
    """Rebuild the displayed hypergraph from the client's connected components.

    Reads a JSON request body with ``config`` (``variant``, ``s``,
    ``singleton_type``), ``singletons`` and ``cc_dict`` (component id ->
    vertex list). Each component id becomes a hyperedge name with its ','
    separators replaced by '|'. The resulting dict is written to
    ``uploads/current_output.txt``.

    Returns:
        JSON response with ``hyper_data`` (node-link data of the collapsed
        hypergraph's bipartite graph) and ``line_data`` (the recovered line
        graph).
    """
    jsdata = json.loads(request.get_data())
    config = jsdata['config']
    variant = config['variant']
    s = int(config['s'])
    singleton_type = config['singleton_type']
    singletons = jsdata['singletons']
    cc_dict = jsdata['cc_dict']

    hgraph_dict = {
        he.replace(",", "|"): v_list
        for he, v_list in cc_dict.items()
    }

    write_output_hypergraph(
        hgraph_dict, path.join(APP_STATIC, "uploads/current_output.txt"))

    # If variant is clique_expansion, recover_linegraph() will give dual line
    # graph with hgraph_dict.
    lgraph = recover_linegraph(hgraph_dict, singletons, s=s)
    hgraph = hnx.Hypergraph(hgraph_dict)
    if variant == "clique_expansion":
        hgraph = hgraph.dual()
    chgraph = collapse_hypergraph(hgraph)
    chgraph = nx.readwrite.json_graph.node_link_data(chgraph.bipartite())
    if singleton_type == "grey_out":
        assign_hgraph_singletons(chgraph, singletons)
    return jsonify(hyper_data=chgraph, line_data=lgraph)
예제 #7
0
파일: lesmis.py 프로젝트: pnnl/HyperNetX
def lesmis_hypergraph_from_df(df, by="Chapter", on="Characters"):
    """Build a hypergraph with one edge per group of ``df``.

    ``df`` is grouped by every column up to and including ``by`` (in column
    order). Each group becomes an edge whose name joins the group-key values
    with ".", containing the set of values of column ``on`` in that group.

    Raises:
        ValueError: if ``by`` is not a column of ``df``.
    """
    cols = df.columns.tolist()
    group_cols = cols[:cols.index(by) + 1]

    edges = {}
    for key, members in df.groupby(group_cols)[on]:
        # pandas < 2.0 yields a scalar key when grouping by a one-element
        # list, while pandas >= 2.0 yields a 1-tuple. Normalise the key so
        # that ``by`` being the first column works on both.
        if not isinstance(key, tuple):
            key = (key,)
        edges[".".join(map(str, key))] = set(members)
    return hnx.Hypergraph(edges)
예제 #8
0
파일: lesmis.py 프로젝트: toggled/HyperNetX
def hypergraph_from_df(df, by='Chapter', on='Characters'):
    """Build a hypergraph with one edge per group of ``df``.

    ``df`` is grouped by every column up to and including ``by`` (in column
    order). Each group becomes an edge whose name joins the group-key values
    with '.', containing the set of values of column ``on`` in that group.

    Raises:
        ValueError: if ``by`` is not a column of ``df``.
    """
    cols = df.columns.tolist()
    group_cols = cols[:cols.index(by) + 1]

    edges = {}
    for key, members in df.groupby(group_cols)[on]:
        # pandas < 2.0 yields a scalar key when grouping by a one-element
        # list, while pandas >= 2.0 yields a 1-tuple. Normalise the key so
        # that ``by`` being the first column works on both.
        if not isinstance(key, tuple):
            key = (key,)
        edges['.'.join(map(str, key))] = set(members)
    return hnx.Hypergraph(edges)
예제 #9
0
 def __init__(self):
     """Build the six-edge 'Fish' example hypergraph.

     Each edge name spells out its member nodes, one character per node.
     """
     fish_edges = ('AB', 'BC', 'ACD', 'BEH', 'CF', 'AG')
     self.edgedict = dict((label, set(label)) for label in fish_edges)
     self.hypergraph = hnx.Hypergraph(self.edgedict, name='Fish')
예제 #10
0
def process_hypergraph_from_csv(graph_file: str):
    """Read a hypergraph from a comma-separated text file.

    Each line has the form ``edge,v1,v2,...``. Trailing whitespace is
    stripped from the line; fields are otherwise kept verbatim. Lines that
    repeat an edge name extend that edge's vertex list. Lines whose edge name
    is empty (e.g. blank lines) are skipped, as the sibling
    ``process_hypergraph`` parser does.

    :param graph_file: path of the CSV-like file
    :return: hnx.Hypergraph built from the collected edges
    """
    hgraph = {}

    with open(graph_file, 'r') as gfile:
        for raw_line in gfile:
            hyperedge, *vertices = raw_line.rstrip().split(',')
            # Without this guard a blank line would add an empty edge ''.
            if not hyperedge:
                continue
            hgraph.setdefault(hyperedge, []).extend(vertices)

    return hnx.Hypergraph(hgraph)
예제 #11
0
 def __init__(self):
     """Build a hypergraph of eight numbered edges over two-letter node codes."""
     # Node codes per edge, written as space-separated strings.
     members_by_edge = {
         1: 'CL CV GE GG MB MC ME MY NP SN',
         2: 'IS JL JV MB ME MR MT MY PG',
         3: 'BL DA FA FN FT FV LI ZE',
         4: 'CO FN TH TM',
         5: 'BM FF FN JA JV MT MY VI',
         6: 'FN JA JV',
         7: 'BM BR CC CH CN FN JU JV PO SC SP SS',
         8: 'FN JA JV PO SP SS',
     }
     self.edgedict = {
         edge_id: set(codes.split())
         for edge_id, codes in members_by_edge.items()
     }
     self.hypergraph = hnx.Hypergraph(self.edgedict)
예제 #12
0
def process_hypergraph(hyper_data: str):
    """Parse comma-separated hypergraph text into a relabelled hypergraph.

    Each line with a non-empty first field has the form ``edge,v1,v2,...``.
    Labels are cleaned by removing single quotes, double quotes and
    whitespace. They are then replaced by generated ids in order of first
    appearance: ``he0, he1, ...`` for edges and ``v0, v1, ...`` for vertices.
    Lines naming the same cleaned edge label are merged into one edge, and
    empty cleaned vertex labels are dropped.

    Returns:
        (hgraph, label_map): the ``hnx.Hypergraph`` over generated ids, and a
        dict mapping every generated id back to its cleaned label.
    """
    hgraph = {}
    label2id = {}   # cleaned hyperedge label -> 'he<n>'
    vlabel2id = {}  # cleaned vertex label -> 'v<n>'

    def clean(label):
        # Strip single quotes and whitespace, then double quotes.
        label = re.sub(r"['\s]+", '', label)
        return label.replace("\"", "")

    for line in hyper_data.split("\n"):
        hyperedge_raw, *vertex_raws = line.rstrip().rsplit(',')
        if hyperedge_raw == "":
            continue

        # Look up by cleaned label so a repeated edge reuses its id and its
        # vertices are merged below.
        hyperedge_label = clean(hyperedge_raw)
        if hyperedge_label not in label2id:
            label2id[hyperedge_label] = 'he' + str(len(label2id))
        hyperedge = label2id[hyperedge_label]

        vertices = []
        for v in vertex_raws:
            v_label = clean(v)
            if v_label == "":
                continue
            if v_label not in vlabel2id:
                vlabel2id[v_label] = 'v' + str(len(vlabel2id))
            vertices.append(vlabel2id[v_label])

        hgraph.setdefault(hyperedge, []).extend(vertices)

    label_map = {ID: label for label, ID in label2id.items()}
    for label, ID in vlabel2id.items():
        label_map[ID] = label

    return hnx.Hypergraph(hgraph), label_map
예제 #13
0
def undo_hgraph_expansion():
    """Merge two split components back into a single component.

    Reads a JSON request body with ``cc_dict`` (comma-separated hyperedge
    keys -> vertex list), ``config`` (``variant``, ``s``,
    ``singleton_type``), ``hyperedges2vertices``, ``singletons`` and
    ``cc_id`` = [cc1_id, cc2_id, merged_id]. Entries ``cc1_id`` and
    ``cc2_id`` are removed from ``cc_dict``. ``merged_id`` is then mapped to
    the union of the vertices of all their hyperedges, in first-seen order
    without duplicates. The updated hypergraph is written to
    ``uploads/current_output.txt``.

    Returns:
        JSON response with ``hyper_data``, the updated ``cc_dict`` and
        ``line_data``.

    Raises:
        KeyError: if ``cc1_id`` or ``cc2_id`` is not in ``cc_dict``.
    """
    jsdata = json.loads(request.get_data())
    cc_dict = jsdata['cc_dict']
    config = jsdata['config']
    variant = config['variant']
    s = int(config['s'])
    singleton_type = config['singleton_type']
    hyperedges2vertices = jsdata['hyperedges2vertices']
    singletons = jsdata['singletons']

    cc1_id = jsdata['cc_id'][0]
    cc2_id = jsdata['cc_id'][1]
    cc_id = jsdata['cc_id'][2]

    # cc1 and cc2 are mutually exclusive. dict.fromkeys de-duplicates in O(n)
    # while keeping a deterministic order; set() order varies between runs
    # because of string-hash randomisation.
    cc_keys = dict.fromkeys(cc1_id.split(",") + cc2_id.split(","))
    cc_list = list(dict.fromkeys(
        v for he in cc_keys for v in hyperedges2vertices[he]))

    # Remove the split components before inserting the merged one so the
    # merged entry survives even if its id coincides with one of them.
    del cc_dict[cc1_id]
    del cc_dict[cc2_id]
    cc_dict[cc_id] = cc_list

    hgraph_dict = {
        he.replace(",", "|"): v_list
        for he, v_list in cc_dict.items()
    }
    write_output_hypergraph(
        hgraph_dict, path.join(APP_STATIC, "uploads/current_output.txt"))

    # If variant is clique_expansion, recover_linegraph() will give dual line
    # graph with hgraph_dict.
    lgraph = recover_linegraph(hgraph_dict, singletons, s=s)
    hgraph = hnx.Hypergraph(hgraph_dict)
    if variant == "clique_expansion":
        hgraph = hgraph.dual()
    chgraph = collapse_hypergraph(hgraph)
    chgraph = nx.readwrite.json_graph.node_link_data(chgraph.bipartite())
    if singleton_type == "grey_out":
        assign_hgraph_singletons(chgraph, singletons)
    return jsonify(hyper_data=chgraph, cc_dict=cc_dict, line_data=lgraph)
예제 #14
0
    def draw_graph_hnx(self):
        """Draw the stored flow-graph constraints as a hypergraph and wait.

        Reads the ``edge`` collection of ``flow_graph_constraint`` from
        ``self.GRAPH_CLIENT``. For each record, the i-th entry of
        ``edge_names`` becomes an edge joining the label of ``source[i]``
        and the label of ``destination``. A node label is
        ``<node_type>-<type>-<last 3 characters of node[type]>``.
        Finally calls ``plt.waitforbuttonpress()`` on the drawing.
        """
        def node_label(node):
            # The node's 'type' field names the key holding its identifier.
            return (node['node_type'] + '-' + node['type'] + '-' +
                    node[node['type']][-3:])

        flow_graph_constraint_database = self.GRAPH_CLIENT[
            'flow_graph_constraint']
        edge_list = flow_graph_constraint_database['edge']

        scenes = {}
        for edge in edge_list:
            edge_names = edge['edge_names']
            if not edge_names:
                continue
            sources = edge['source']
            # The destination is shared by every edge name of this record,
            # so its label is built once.
            destination_label = node_label(edge['destination'])
            for i, edge_name in enumerate(edge_names):
                scenes[edge_name] = [node_label(sources[i]), destination_label]
        H = hnx.Hypergraph(scenes)
        hnx.draw(H)
        plt.waitforbuttonpress()
예제 #15
0
def collapse_hypergraph(hgraph):
    """Collapse duplicate edges and nodes and name the merged items with '|'.

    Applies ``hgraph.collapse_edges()`` followed by ``collapse_nodes()``.
    In the resulting incidence dict, each edge key and each vertex is
    iterated as a collection of string names (presumably the equivalence
    class that was merged; confirm against the hypernetx version in use).
    Each collection is flattened to one string by joining its members
    with "|".

    Returns:
        hnx.Hypergraph whose edge and node names are the "|"-joined strings.
    """
    collapsed = hgraph.collapse_edges().collapse_nodes().incidence_dict
    return hnx.Hypergraph({
        "|".join(hkey): ["|".join(v_list) for v_list in vertices]
        for hkey, vertices in collapsed.items()
    })
예제 #16
0
def process_hypergraph(hyper_data: str):
    """Parse comma-separated hypergraph text into a relabelled hypergraph.

    Each line with a non-empty first field has the form ``edge,v1,v2,...``.
    Labels are cleaned by removing single quotes and whitespace. They are
    then replaced by generated ids in order of first appearance:
    ``he0, he1, ...`` for edges and ``v0, v1, ...`` for vertices. Lines
    naming the same cleaned edge label are merged into one edge, and empty
    cleaned vertex labels are dropped.

    Returns:
        (hgraph, label_map): the ``hnx.Hypergraph`` over generated ids, and a
        dict mapping every generated id back to its cleaned label.
    """
    hgraph = {}
    # Edges and vertices use separate tables. With one shared table, a vertex
    # whose label equals an edge label would be given that edge's 'he' id.
    edge_label2id = {}
    vertex_label2id = {}
    for line in hyper_data.split("\n"):
        hyperedge_raw, *vertex_raws = line.rstrip().rsplit(',')
        if hyperedge_raw == "":
            continue

        # Look up by cleaned label so a repeated edge reuses its id and its
        # vertices are merged below.
        hyperedge_label = re.sub(r"['\s]+", '', hyperedge_raw)
        if hyperedge_label not in edge_label2id:
            edge_label2id[hyperedge_label] = 'he' + str(len(edge_label2id))
        hyperedge = edge_label2id[hyperedge_label]

        vertices = []
        for v in vertex_raws:
            v_label = re.sub(r"['\s]+", '', v)
            if v_label == "":
                continue
            if v_label not in vertex_label2id:
                vertex_label2id[v_label] = 'v' + str(len(vertex_label2id))
            vertices.append(vertex_label2id[v_label])

        hgraph.setdefault(hyperedge, []).extend(vertices)

    label_map = {ID: label for label, ID in edge_label2id.items()}
    label_map.update({ID: label for label, ID in vertex_label2id.items()})

    return hnx.Hypergraph(hgraph), label_map
예제 #17
0
 def __init__(self):
     """Build the three-edge "TriLoop" example hypergraph."""
     self.edgedict = {}
     for edge_name in ("AB", "BC", "ACD"):
         # Each edge name spells out its member nodes.
         self.edgedict[edge_name] = set(edge_name)
     self.hypergraph = hnx.Hypergraph(self.edgedict, name="TriLoop")
예제 #18
0
 def __init__(self):
     """Build the three-edge 'TriLoop' example hypergraph."""
     # (edge name, member nodes) pairs, in insertion order.
     edge_specs = (
         ('AB', ('A', 'B')),
         ('BC', ('B', 'C')),
         ('ACD', ('A', 'C', 'D')),
     )
     self.edgedict = {name: set(members) for name, members in edge_specs}
     self.hypergraph = hnx.Hypergraph(self.edgedict, name='TriLoop')
예제 #19
0
def hgraph_expansion():
    """Split one connected component into its source-side and target-side parts.

    Reads a JSON request body with:

    - ``config`` (``variant``, ``s``, ``singleton_type``, ``weight_type``);
    - ``cc_dict`` (comma-separated hyperedge keys -> vertex list);
    - ``edge``, whose ``[weight_type]['nodes_subsets']`` provides
      ``source_cc`` and ``target_cc``;
    - ``hyperedges2vertices`` and ``singletons``.

    The first component containing hyperedges from both ``source_cc`` and
    ``target_cc`` is replaced by two components. One holds its hyperedges that
    are in ``source_cc``; the other holds the remaining hyperedges. Each new
    component maps to the union of its hyperedges' vertices, in first-seen
    order without duplicates. The updated hypergraph is written to
    ``uploads/current_output.txt``.

    Returns:
        JSON response with ``hyper_data``, the updated ``cc_dict``,
        ``line_data`` and ``cc_id`` = [cc1_id, cc2_id, removed_id]. If no
        component spans both sides, ``cc_dict`` is unchanged and every
        ``cc_id`` entry is null instead of the handler crashing.
    """
    jsdata = json.loads(request.get_data())
    config = jsdata['config']
    variant = config['variant']
    s = int(config['s'])
    singleton_type = config['singleton_type']
    weight_type = config['weight_type']

    cc_dict = jsdata['cc_dict']
    nodes_subsets = jsdata['edge'][weight_type]['nodes_subsets']
    source_cc = nodes_subsets['source_cc']
    target_cc = nodes_subsets['target_cc']
    hyperedges2vertices = jsdata['hyperedges2vertices']
    singletons = jsdata['singletons']

    def merged_vertices(hyperedge_keys):
        # Union of the hyperedges' vertices, first-seen order, no repeats.
        return list(dict.fromkeys(
            v for he in hyperedge_keys for v in hyperedges2vertices[he]))

    source_set = set(source_cc)
    target_set = set(target_cc)
    cc1_id = cc2_id = cc_removed = None
    # Iterate over a snapshot of the keys because cc_dict is modified once
    # the matching component is found.
    for cc_key in list(cc_dict):
        hyperedge_keys = cc_key.split(",")
        if (source_set.isdisjoint(hyperedge_keys)
                or target_set.isdisjoint(hyperedge_keys)):
            continue
        cc1_keys = [he for he in hyperedge_keys if he in source_set]
        cc2_keys = [he for he in hyperedge_keys if he not in source_set]
        cc1_id = ",".join(cc1_keys)
        cc2_id = ",".join(cc2_keys)
        del cc_dict[cc_key]
        cc_dict[cc1_id] = merged_vertices(cc1_keys)
        cc_dict[cc2_id] = merged_vertices(cc2_keys)
        cc_removed = cc_key
        break

    hgraph_dict = {
        he.replace(",", "|"): v_list
        for he, v_list in cc_dict.items()
    }
    write_output_hypergraph(
        hgraph_dict, path.join(APP_STATIC, "uploads/current_output.txt"))

    # If variant is clique_expansion, recover_linegraph() will give dual line
    # graph with hgraph_dict.
    lgraph = recover_linegraph(hgraph_dict, singletons, s=s)
    hgraph = hnx.Hypergraph(hgraph_dict)
    if variant == "clique_expansion":
        hgraph = hgraph.dual()
    chgraph = collapse_hypergraph(hgraph)
    chgraph = nx.readwrite.json_graph.node_link_data(chgraph.bipartite())
    if singleton_type == "grey_out":
        assign_hgraph_singletons(chgraph, singletons)
    return jsonify(hyper_data=chgraph,
                   cc_dict=cc_dict,
                   line_data=lgraph,
                   cc_id=[cc1_id, cc2_id, cc_removed])
예제 #20
0
def compute_graphs(config):
    """Compute and cache line graphs and barcodes for the uploaded hypergraph.

    The function runs in these steps:

    1. Loads ``uploads/current_hypergraph<suffix>.json``. The suffix is ""
       for ``hgraph_type == "collapsed_version"`` and "_original" for
       ``"original_version"``.
    2. Builds the line graph and dual line graph with parameter ``s``.
    3. Computes intersection-size ("_is") and Jaccard-index ("_ji") barcodes
       for both.
    4. Writes all six results under ``uploads/``.
    5. Reads back the barcode selected by ``variant`` and ``weight_type``
       from the file just written.

    Returns:
        (hgraph, lgraph, barcode): node-link data of the hypergraph's
        bipartite graph (passed through ``assign_hgraph_singletons``), the
        line graph for the selected variant (the dual one for
        ``"clique_expansion"``), and the selected barcode.

    Raises:
        ValueError: for an unrecognised ``hgraph_type`` or ``weight_type``.
            Previously these failed with UnboundLocalError, and an unknown
            ``weight_type`` only failed after every file had been written.
    """
    hgraph_suffixes = {"collapsed_version": "", "original_version": "_original"}
    weight_suffixes = {"intersection_size": "_is", "jaccard_index": "_ji"}

    hgraph_type = config['hgraph_type']
    variant = config['variant']
    s = int(config['s'])
    singleton_type = config['singleton_type']
    weight_type = config['weight_type']
    if hgraph_type not in hgraph_suffixes:
        raise ValueError(f"unknown hgraph_type: {hgraph_type!r}")
    if weight_type not in weight_suffixes:
        raise ValueError(f"unknown weight_type: {weight_type!r}")
    f_hgraph = hgraph_suffixes[hgraph_type]

    def upload_path(stem):
        # e.g. "current_barcode_is" -> APP_STATIC/uploads/current_barcode_is<f_hgraph>.json
        return path.join(APP_STATIC, "uploads/" + stem + f_hgraph + ".json")

    # 1. load hgraph
    with open(upload_path("current_hypergraph")) as f:
        hgraph_json = json.load(f)
    hgraph = hnx.Hypergraph(hgraph_json)
    lgraph = convert_to_line_graph(hgraph.incidence_dict,
                                   s=s,
                                   singleton_type=singleton_type)
    dual_lgraph = compute_dual_line_graph(hgraph,
                                          s=s,
                                          singleton_type=singleton_type)
    hgraph = nx.readwrite.json_graph.node_link_data(hgraph.bipartite())

    # 2. Barcodes weighted by intersection size.
    barcode_is = compute_barcode(lgraph)
    dual_barcode_is = compute_barcode(dual_lgraph)
    write_json_file(barcode_is, upload_path("current_barcode_is"))
    write_json_file(dual_barcode_is, upload_path("current_dual_barcode_is"))

    # 3. Barcodes weighted by Jaccard index (ji: weight = 1/jaccard_index).
    barcode_ji = compute_barcode(lgraph, weight_col="jaccard_index")
    dual_barcode_ji = compute_barcode(dual_lgraph, weight_col="jaccard_index")

    write_json_file(lgraph, upload_path("current_linegraph"))
    write_json_file(dual_lgraph, upload_path("current_dual_linegraph"))
    write_json_file(barcode_ji, upload_path("current_barcode_ji"))
    write_json_file(dual_barcode_ji, upload_path("current_dual_barcode_ji"))

    # 4. Select the variant and read its barcode back from the file written
    # above.
    if variant == "clique_expansion":
        lgraph = dual_lgraph
        barcode_stem = "current_dual_barcode"
    else:
        barcode_stem = "current_barcode"
    with open(upload_path(barcode_stem + weight_suffixes[weight_type])) as f:
        barcode = json.load(f)

    assign_hgraph_singletons(hgraph,
                             lgraph['singletons'],
                             singleton_type=singleton_type)

    return hgraph, lgraph, barcode
예제 #21
0
def sbs_hypergraph():
    """Return the ``SevenBySix`` example edges as a hypergraph named 'sbsh'."""
    edges = SevenBySix().edgedict
    return hnx.Hypergraph(edges, name='sbsh')
예제 #22
0
def sbsd_hypergraph():
    """Return the ``SBSDupes`` example edges as an unnamed hypergraph."""
    return hnx.Hypergraph(SBSDupes().edgedict)
예제 #23
0
def H():
    """Return the karate-club graph as a hypergraph of two-node edges.

    Edge ``e<i>`` is the i-th edge of ``nx.karate_club_graph().edges()``.
    """
    karate = nx.karate_club_graph()
    edges = {}
    for index, endpoints in enumerate(karate.edges()):
        edges[f'e{index}'] = endpoints
    return hnx.Hypergraph(edges)
예제 #24
0
import json
import hypernetx as hnx
import matplotlib.pyplot as plt
from hypernetx.drawing.rubber_band import draw

# Load the example incidence data (path is relative to the working directory)
# and display it as a rubber-band hypergraph drawing.
with open("../toy_data/iris_graph.json") as graph_file:
    graph = json.load(graph_file)

H = hnx.Hypergraph(graph)

draw(H)
plt.show()
예제 #25
0
 def __init__(self):
     """Build the "Fish" example hypergraph together with reference chain data.

     ``self.edgedict`` maps each edge name to the set of its single-letter
     nodes, and ``self.hypergraph`` is built from it. ``self.state`` holds
     precomputed "chains" and "bkMatrix" entries for this hypergraph
     (presumably expected values for homology-related tests; confirm with
     the consuming tests).
     """
     A, B, C, D, E, F, G, H = "A", "B", "C", "D", "E", "F", "G", "H"
     AB, BC, ACD, BEH, CF, AG = "AB", "BC", "ACD", "BEH", "CF", "AG"
     self.edgedict = {
         AB: {A, B},
         BC: {B, C},
         ACD: {A, C, D},
         BEH: {B, E, H},
         CF: {C, F},
         AG: {A, G},
     }
     self.hypergraph = hnx.Hypergraph(self.edgedict, name="Fish")
     state_dict = {
         # chains[k]: every (k+1)-element subset of some edge, as a tuple of
         # node names, listed in lexicographic order.
         "chains": {
             0: [("A", ), ("B", ), ("C", ), ("D", ), ("E", ), ("F", ),
                 ("G", ), ("H", )],
             1: [
                 ("A", "B"),
                 ("A", "C"),
                 ("A", "D"),
                 ("A", "G"),
                 ("B", "C"),
                 ("B", "E"),
                 ("B", "H"),
                 ("C", "D"),
                 ("C", "F"),
                 ("E", "H"),
             ],
             2: [("A", "C", "D"), ("B", "E", "H")],
             3: [],
         },
         # bkMatrix[k]: 0/1 matrix with rows indexed by chains[k-1] and
         # columns by chains[k]; an entry is 1 when the row chain is a face
         # (one node removed) of the column chain.
         "bkMatrix": {
             1:
             np.array([
                 [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
                 [1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
                 [0, 1, 0, 0, 1, 0, 0, 1, 1, 0],
                 [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
                 [0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0, 1, 0, 0, 1],
             ]),
             2:
             np.array([
                 [0, 0],
                 [1, 0],
                 [1, 0],
                 [0, 0],
                 [0, 0],
                 [0, 1],
                 [0, 1],
                 [1, 0],
                 [0, 0],
                 [0, 1],
             ]),
             # Two 2-chains and no 3-chains: a 2 x 0 integer matrix.
             3:
             np.array([[], []], dtype=np.int64),
         },
     }
     self.state = state_dict