def test_preimage_random_grid_k_median_nb():
    """Grid-search ``preimage_random`` over median-set sizes and ``k`` on MUTAG.

    For every ``nb_median`` in ``nb_median_range``, ``nb_median`` graphs are
    sampled (after ``random.seed(1)``) from the MUTAG graphs whose target is 1.
    A Gram matrix over ``Gn`` followed by the sampled graphs is assembled from
    a precomputed Gram matrix on disk. For every ``k`` in ``k_range``,
    ``preimage_random`` is run; the returned pre-image is drawn to
    ``results/preimage_random/`` and its SOD to the sampled graphs is computed
    with ``ged_median``. Summary lists are printed at the end.

    The file ``results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm.npz``
    must exist, with arrays ``gm`` (the Gram matrix) and ``gmtime`` (its
    computation time). It can be produced with
    ``compute_kernel(Gn, gkernel, True)`` followed by
    ``np.savez(..., gm=km, gmtime=time_km)``. It is assumed to be the kernel
    matrix of exactly the target-1 graphs in ``Gn`` order -- TODO confirm.
    """
    ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG_A.txt',
          'extra_params': {}}  # node/edge symb
    Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    remove_edges(Gn)
    gkernel = 'marginalizedkernel'

    r_max = 5  # iteration limit for pre-image.
    l_max = 500  # update limit for random generation.
    # parameters for GED function
    ged_cost = 'CHEM_1'
    ged_method = 'IPFP'
    saveGXL = 'gedlib'

    # numbers of graphs whose median is computed.
    nb_median_range = [2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
    # numbers of nearest neighbors.
    k_range = [5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 100]

    # keep only the graphs whose target is 1 (the positive group).
    idx_dict = get_same_item_indices(y_all)
    Gn = [Gn[i] for i in idx_dict[1]]
    n_gn = len(Gn)

    # The precomputed Gram matrix does not depend on the loop variables, so it
    # is loaded once; the context manager closes the .npz archive.
    with np.load('results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm.npz') as gmfile:
        km_tmp = gmfile['gm']
        time_km = gmfile['gmtime']
    # Symmetric Gram matrix of Gn: the upper triangle of km_tmp (diagonal
    # included) mirrored onto the lower triangle.
    km_upper = np.triu(km_tmp[:n_gn, :n_gn])
    km_gn = km_upper + np.triu(km_upper, 1).T

    time_list = []
    dis_ks_min_list = []
    sod_gs_list = []
    sod_gs_min_list = []
    nb_updated_list = []
    g_best = []
    for nb_median in nb_median_range:
        print('\n-------------------------------------------------------')
        print('number of median graphs =', nb_median)
        random.seed(1)
        idx_rdm = random.sample(range(n_gn), nb_median)
        print('graphs chosen:', idx_rdm)
        Gn_median = [Gn[idx].copy() for idx in idx_rdm]

        # Gram matrix of Gn followed by the nb_median sampled graphs; the
        # rows/columns of each sampled graph are copied from its original in Gn.
        idx_arr = np.asarray(idx_rdm, dtype=int)
        km = np.zeros((n_gn + nb_median, n_gn + nb_median))
        km[:n_gn, :n_gn] = km_gn
        km[:n_gn, n_gn:] = km_gn[:, idx_arr]
        km[n_gn:, :n_gn] = km_gn[:, idx_arr].T
        km[n_gn:, n_gn:] = km_gn[np.ix_(idx_arr, idx_arr)]

        # equal weights for every sampled graph.
        alpha_range = [1 / nb_median] * nb_median

        # per-k results for this median-set size.
        times_k = []
        dis_k = []
        sods_k = []
        sods_min_k = []
        updates_k = []
        ghats_k = []
        time_list.append(times_k)
        dis_ks_min_list.append(dis_k)
        sod_gs_list.append(sods_k)
        sod_gs_min_list.append(sods_min_k)
        nb_updated_list.append(updates_k)
        g_best.append(ghats_k)

        for k in k_range:
            print('\n++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n')
            print('k =', k)
            time0 = time.time()
            dhat, ghat, nb_updated = preimage_random(
                Gn, Gn_median, alpha_range, range(n_gn, n_gn + nb_median),
                km, k, r_max, l_max, gkernel)

            # includes the time originally spent computing the Gram matrix.
            time_total = time.time() - time0 + time_km
            print('time: ', time_total)
            times_k.append(time_total)
            print('\nsmallest distance in kernel space: ', dhat)
            dis_k.append(dhat)
            ghats_k.append(ghat)
            print('\nnumber of updates of the best graph: ', nb_updated)
            updates_k.append(nb_updated)

            # show the best graph and save it to file.
            print('the shortest distance is', dhat)
            print('one of the possible corresponding pre-images is')
            nx.draw(ghat, labels=nx.get_node_attributes(ghat, 'atom'),
                    with_labels=True)
            plt.savefig(f'results/preimage_random/mutag_median_nb{nb_median}'
                        f'_k{k}.png', format="PNG")
            plt.clf()

            # compute the corresponding sod in graph space.
            sod_tmp, _ = ged_median([ghat], Gn_median, ged_cost=ged_cost,
                                    ged_method=ged_method, saveGXL=saveGXL)
            sods_k.append(sod_tmp)
            sods_min_k.append(np.min(sod_tmp))
            print('\nsmallest sod in graph space: ', np.min(sod_tmp))

    print('\nsods in graph space: ', sod_gs_list)
    print('\nsmallest sod in graph space for each set of median graphs and k: ',
          sod_gs_min_list)
    print('\nsmallest distance in kernel space for each set of median graphs and k: ',
          dis_ks_min_list)
    print('\nnumber of updates of the best graph for each set of median graphs and k by IAM: ',
          nb_updated_list)
    print('\ntimes:', time_list)
# Example #2
# 0
    def generate_graph(G, pi_p_forward):
        """Build new candidate graph(s) from ``G`` by one label/edge update.

        Node labels (symbolic case) or node attributes (non-symbolic case),
        and then edges, of copies of ``G`` are re-chosen. The choice follows
        the nodes/edges of the median graphs ``Gn_median`` onto which ``G`` is
        mapped by ``pi_p_forward``. The generated graphs are then evaluated
        with ``ged_median``.

        Parameters
        ----------
        G : networkx graph
            Current candidate graph; only copies of it are modified.
        pi_p_forward : list
            One node map per graph in ``Gn_median``. ``pi_p_forward[idx][i]``
            is the node of ``Gn_median[idx]`` matched to the ``i``-th node of
            ``G`` (in ``G.nodes()`` order), or ``node_ir`` if that node is
            matched to a removal/insertion.

        Returns
        -------
        tuple
            ``(G_new_list, pi_forward_list, dis_list)``: the generated graphs,
            plus the node maps and distances ``ged_median`` returns for them.
            Duplicate graphs are dropped when the dataset has neither node
            nor edge attributes.

        Notes
        -----
        Reads from the enclosing scope: ``Gn_median``, ``ds_attrs``,
        ``node_label_set``, ``edge_label_set``, ``node_ir``, ``label_r``,
        ``removeNodes``, ``allBestNodes``, ``allBestEdges``, ``node_label``,
        ``edge_label``, ``c_ei``, ``c_er``, ``c_es`` and ``params_ged``.
        """
        # Position of every node of G in G.nodes() order. pi_p_forward is
        # indexed by these positions. They are taken from G, not from the
        # generated graphs, whose node lists shrink when a node is removed.
        node_pos = {nd: ndi for ndi, nd in enumerate(G.nodes())}
        n_median = len(Gn_median)

        def matched_edge_data(nd1, nd2):
            # Attribute dicts of the median-graph edges that the node pair
            # (nd1, nd2) of G is mapped onto. Median graphs in which either
            # endpoint is unmatched, or the edge is absent, contribute nothing.
            pos1, pos2 = node_pos[nd1], node_pos[nd2]
            data = []
            for idx, g in enumerate(Gn_median):
                pi_i = pi_p_forward[idx][pos1]
                pi_j = pi_p_forward[idx][pos2]
                if g.has_node(pi_i) and g.has_node(pi_j) and \
                        g.has_edge(pi_i, pi_j):
                    data.append(g.edges[pi_i, pi_j])
            return data

        G_new_list = [G.copy()]  # all "best" graphs generated in this iteration.

        # update vertex labels.
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for ndi, nd in enumerate(G.nodes()):
                # nodes of the median graphs matched to nd.
                matched = [pi_p_forward[idx][ndi] for idx in range(n_median)]
                # h_i0: number of median graphs matching nd to a node with
                # the given label.
                label_list = list(node_label_set)
                h_i0_list = [
                    sum(1 for g, pi_i in zip(Gn_median, matched)
                        if pi_i != node_ir and g.nodes[pi_i][node_label] == label)
                    for label in label_list]
                # case when the node is to be removed.
                if removeNodes:
                    # @todo: maybe this can be added to the node_label_set above.
                    h_i0_list.append(sum(1 for pi_i in matched if pi_i == node_ir))
                    label_list.append(label_r)
                # get the best labels.
                h_i0_arr = np.asarray(h_i0_list)
                idx_max = np.argwhere(
                    h_i0_arr == np.max(h_i0_arr)).flatten().tolist()
                if allBestNodes:  # choose all best graphs.
                    # generate "best" graphs with regard to "best" node labels.
                    G_new_list_nd = []
                    for g in G_new_list:
                        for idx in idx_max:
                            nl = label_list[idx]
                            g_tmp = g.copy()
                            if nl == label_r:
                                g_tmp.remove_node(nd)
                            else:
                                g_tmp.nodes[nd][node_label] = nl
                            G_new_list_nd.append(g_tmp)
                    G_new_list = G_new_list_nd
                else:
                    # choose one of the best randomly.
                    best_label = label_list[
                        idx_max[random.randint(0, len(idx_max) - 1)]]
                    g_new = G_new_list[0]
                    if best_label == label_r:
                        g_new.remove_node(nd)
                    else:
                        g_new.nodes[nd][node_label] = best_label
                    G_new_list = [g_new]
        else:  # labels are non-symbolic
            for ndi, nd in enumerate(G.nodes()):
                Si_norm = 0
                phi_i_bar = np.zeros(ds_attrs['node_attr_dim'])
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][ndi]
                    # @todo: what if no g has node? phi_i_bar = 0?
                    if g.has_node(pi_i):
                        Si_norm += 1
                        phi_i_bar += np.array(
                            [float(itm) for itm in g.nodes[pi_i]['attributes']])
                # mean attribute vector; NaN (with a RuntimeWarning) if
                # Si_norm == 0.
                phi_i_bar /= Si_norm
                G_new_list[0].nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            label_list = list(edge_label_set)
            G_new_list_edge = []
            for g_new in G_new_list:
                nd_list = list(g_new.nodes())
                g_tmp_list = [g_new.copy()]
                # @todo: not just edges, but all pairs of nodes
                for nd1i, nd1 in enumerate(nd_list):
                    for nd2 in nd_list[nd1i + 1:]:
                        edge_data = matched_edge_data(nd1, nd2)
                        # number of median graphs with a matched edge.
                        sij_norm = len(edge_data)
                        matched_labels = [d[edge_label] for d in edge_data]
                        h_ij0_list = [
                            sum(1 for el in matched_labels if el == label)
                            for label in label_list]
                        # a_ij is set to 1 iff the chosen count exceeds this
                        # bound, which depends on the edit costs.
                        threshold = n_median * c_er / c_es + sij_norm * (
                            1 - (c_er + c_ei) / c_es)

                        # get the best labels.
                        h_ij0_arr = np.asarray(h_ij0_list)
                        idx_max = np.argwhere(
                            h_ij0_arr == np.max(h_ij0_arr)).flatten().tolist()
                        if allBestEdges:  # choose all best graphs.
                            G_new_list_ed = []
                            for g_tmp in g_tmp_list:
                                for idx in idx_max:
                                    g_tmp_copy = g_tmp.copy()
                                    if h_ij0_list[idx] > threshold:
                                        if not g_tmp_copy.has_edge(nd1, nd2):
                                            g_tmp_copy.add_edge(nd1, nd2)
                                        g_tmp_copy.edges[nd1, nd2][
                                            edge_label] = label_list[idx]
                                    elif g_tmp_copy.has_edge(nd1, nd2):
                                        g_tmp_copy.remove_edge(nd1, nd2)
                                    G_new_list_ed.append(g_tmp_copy)
                            g_tmp_list = G_new_list_ed
                        else:  # choose one of the best randomly.
                            idx_best = idx_max[
                                random.randint(0, len(idx_max) - 1)]
                            if h_ij0_list[idx_best] > threshold:
                                if not g_new.has_edge(nd1, nd2):
                                    g_new.add_edge(nd1, nd2)
                                g_new.edges[nd1, nd2][
                                    edge_label] = label_list[idx_best]
                            elif g_new.has_edge(nd1, nd2):
                                g_new.remove_edge(nd1, nd2)
                            g_tmp_list = [g_new]
                G_new_list_edge += g_tmp_list
            G_new_list = [ggg.copy() for ggg in G_new_list_edge]

        else:  # if edges are unlabeled
            # @todo: works only for undirected graphs.
            for g_tmp in G_new_list:
                nd_list = list(g_tmp.nodes())
                for nd1i, nd1 in enumerate(nd_list):
                    for nd2 in nd_list[nd1i + 1:]:
                        sij_norm = len(matched_edge_data(nd1, nd2))
                        if sij_norm > n_median * c_er / (c_er + c_ei):
                            if not g_tmp.has_edge(nd1, nd2):
                                g_tmp.add_edge(nd1, nd2)
                        # @todo: should pairs exactly at the bound be kept unchanged?
                        elif g_tmp.has_edge(nd1, nd2):
                            g_tmp.remove_edge(nd1, nd2)

        # @todo: should we update all graphs generated or just the best ones?
        dis_list, pi_forward_list = ged_median(G_new_list,
                                               Gn_median,
                                               params_ged=params_ged)
        # @todo: should we remove the identical and connectivity check?
        # Don't know which is faster.
        if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
            G_new_list, idx_list = remove_duplicates(G_new_list)
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
            dis_list = [dis_list[idx] for idx in idx_list]

        return G_new_list, pi_forward_list, dis_list
# Example #3
# 0
def iam_upgraded(
    Gn_median,
    Gn_candidate,
    c_ei=3,
    c_er=3,
    c_es=1,
    ite_max=50,
    epsilon=0.001,
    node_label='atom',
    edge_label='bond_type',
    connected=False,
    removeNodes=True,
    allBestInit=False,
    allBestNodes=False,
    allBestEdges=False,
    allBestOutput=False,
    params_ged={
        'lib':
        'gedlibpy',
        'cost':
        'CHEM_1',
        'method':
        'IPFP',
        'edit_cost_constant': [],
        'stabilizer':
        None,
        'algo_options':
        '--threads 8 --initial-solutions 40 --ratio-runs-from-initial-solutions 1'
    }):
    """See my name, then you know what I do.
    """
    #    Gn_median = Gn_median[0:10]
    #    Gn_median = [nx.convert_node_labels_to_integers(g) for g in Gn_median]
    node_ir = np.inf  # corresponding to the node remove and insertion.
    label_r = 'thanksdanny'  # the label for node remove. # @todo: make this label unrepeatable.
    ds_attrs = get_dataset_attributes(
        Gn_median + Gn_candidate,
        attr_names=['edge_labeled', 'node_attr_dim', 'edge_attr_dim'],
        edge_label=edge_label)
    node_label_set = get_node_labels(Gn_median, node_label)
    edge_label_set = get_edge_labels(Gn_median, edge_label)

    def generate_graph(G, pi_p_forward):
        """Build new candidate graph(s) from ``G`` by one label/edge update.

        Node labels (symbolic case) or node attributes (non-symbolic case),
        and then edges, of copies of ``G`` are re-chosen. The choice follows
        the nodes/edges of the median graphs ``Gn_median`` onto which ``G`` is
        mapped by ``pi_p_forward``. The generated graphs are then evaluated
        with ``ged_median``.

        Parameters
        ----------
        G : networkx graph
            Current candidate graph; only copies of it are modified.
        pi_p_forward : list
            One node map per graph in ``Gn_median``. ``pi_p_forward[idx][i]``
            is the node of ``Gn_median[idx]`` matched to the ``i``-th node of
            ``G`` (in ``G.nodes()`` order), or ``node_ir`` if that node is
            matched to a removal/insertion.

        Returns
        -------
        tuple
            ``(G_new_list, pi_forward_list, dis_list)``: the generated graphs,
            plus the node maps and distances ``ged_median`` returns for them.
            Duplicate graphs are dropped when the dataset has neither node
            nor edge attributes.
        """
        # Position of every node of G in G.nodes() order. pi_p_forward is
        # indexed by these positions. They are taken from G, not from the
        # generated graphs, whose node lists shrink when a node is removed.
        node_pos = {nd: ndi for ndi, nd in enumerate(G.nodes())}
        n_median = len(Gn_median)

        def matched_edge_data(nd1, nd2):
            # Attribute dicts of the median-graph edges that the node pair
            # (nd1, nd2) of G is mapped onto. Median graphs in which either
            # endpoint is unmatched, or the edge is absent, contribute nothing.
            pos1, pos2 = node_pos[nd1], node_pos[nd2]
            data = []
            for idx, g in enumerate(Gn_median):
                pi_i = pi_p_forward[idx][pos1]
                pi_j = pi_p_forward[idx][pos2]
                if g.has_node(pi_i) and g.has_node(pi_j) and \
                        g.has_edge(pi_i, pi_j):
                    data.append(g.edges[pi_i, pi_j])
            return data

        G_new_list = [G.copy()]  # all "best" graphs generated in this iteration.

        # update vertex labels.
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for ndi, nd in enumerate(G.nodes()):
                # nodes of the median graphs matched to nd.
                matched = [pi_p_forward[idx][ndi] for idx in range(n_median)]
                # h_i0: number of median graphs matching nd to a node with
                # the given label.
                label_list = list(node_label_set)
                h_i0_list = [
                    sum(1 for g, pi_i in zip(Gn_median, matched)
                        if pi_i != node_ir and g.nodes[pi_i][node_label] == label)
                    for label in label_list]
                # case when the node is to be removed.
                if removeNodes:
                    # @todo: maybe this can be added to the node_label_set above.
                    h_i0_list.append(sum(1 for pi_i in matched if pi_i == node_ir))
                    label_list.append(label_r)
                # get the best labels.
                h_i0_arr = np.asarray(h_i0_list)
                idx_max = np.argwhere(
                    h_i0_arr == np.max(h_i0_arr)).flatten().tolist()
                if allBestNodes:  # choose all best graphs.
                    # generate "best" graphs with regard to "best" node labels.
                    G_new_list_nd = []
                    for g in G_new_list:
                        for idx in idx_max:
                            nl = label_list[idx]
                            g_tmp = g.copy()
                            if nl == label_r:
                                g_tmp.remove_node(nd)
                            else:
                                g_tmp.nodes[nd][node_label] = nl
                            G_new_list_nd.append(g_tmp)
                    G_new_list = G_new_list_nd
                else:
                    # choose one of the best randomly.
                    best_label = label_list[
                        idx_max[random.randint(0, len(idx_max) - 1)]]
                    g_new = G_new_list[0]
                    if best_label == label_r:
                        g_new.remove_node(nd)
                    else:
                        g_new.nodes[nd][node_label] = best_label
                    G_new_list = [g_new]
        else:  # labels are non-symbolic
            for ndi, nd in enumerate(G.nodes()):
                Si_norm = 0
                phi_i_bar = np.zeros(ds_attrs['node_attr_dim'])
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][ndi]
                    # @todo: what if no g has node? phi_i_bar = 0?
                    if g.has_node(pi_i):
                        Si_norm += 1
                        phi_i_bar += np.array(
                            [float(itm) for itm in g.nodes[pi_i]['attributes']])
                # mean attribute vector; NaN (with a RuntimeWarning) if
                # Si_norm == 0.
                phi_i_bar /= Si_norm
                G_new_list[0].nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            label_list = list(edge_label_set)
            G_new_list_edge = []
            for g_new in G_new_list:
                nd_list = list(g_new.nodes())
                g_tmp_list = [g_new.copy()]
                # @todo: not just edges, but all pairs of nodes
                for nd1i, nd1 in enumerate(nd_list):
                    for nd2 in nd_list[nd1i + 1:]:
                        edge_data = matched_edge_data(nd1, nd2)
                        # number of median graphs with a matched edge.
                        sij_norm = len(edge_data)
                        matched_labels = [d[edge_label] for d in edge_data]
                        h_ij0_list = [
                            sum(1 for el in matched_labels if el == label)
                            for label in label_list]
                        # a_ij is set to 1 iff the chosen count exceeds this
                        # bound, which depends on the edit costs.
                        threshold = n_median * c_er / c_es + sij_norm * (
                            1 - (c_er + c_ei) / c_es)

                        # get the best labels.
                        h_ij0_arr = np.asarray(h_ij0_list)
                        idx_max = np.argwhere(
                            h_ij0_arr == np.max(h_ij0_arr)).flatten().tolist()
                        if allBestEdges:  # choose all best graphs.
                            G_new_list_ed = []
                            for g_tmp in g_tmp_list:
                                for idx in idx_max:
                                    g_tmp_copy = g_tmp.copy()
                                    if h_ij0_list[idx] > threshold:
                                        if not g_tmp_copy.has_edge(nd1, nd2):
                                            g_tmp_copy.add_edge(nd1, nd2)
                                        g_tmp_copy.edges[nd1, nd2][
                                            edge_label] = label_list[idx]
                                    elif g_tmp_copy.has_edge(nd1, nd2):
                                        g_tmp_copy.remove_edge(nd1, nd2)
                                    G_new_list_ed.append(g_tmp_copy)
                            g_tmp_list = G_new_list_ed
                        else:  # choose one of the best randomly.
                            idx_best = idx_max[
                                random.randint(0, len(idx_max) - 1)]
                            if h_ij0_list[idx_best] > threshold:
                                if not g_new.has_edge(nd1, nd2):
                                    g_new.add_edge(nd1, nd2)
                                g_new.edges[nd1, nd2][
                                    edge_label] = label_list[idx_best]
                            elif g_new.has_edge(nd1, nd2):
                                g_new.remove_edge(nd1, nd2)
                            g_tmp_list = [g_new]
                G_new_list_edge += g_tmp_list
            G_new_list = [ggg.copy() for ggg in G_new_list_edge]

        else:  # if edges are unlabeled
            # @todo: works only for undirected graphs.
            for g_tmp in G_new_list:
                nd_list = list(g_tmp.nodes())
                for nd1i, nd1 in enumerate(nd_list):
                    for nd2 in nd_list[nd1i + 1:]:
                        sij_norm = len(matched_edge_data(nd1, nd2))
                        if sij_norm > n_median * c_er / (c_er + c_ei):
                            if not g_tmp.has_edge(nd1, nd2):
                                g_tmp.add_edge(nd1, nd2)
                        # @todo: should pairs exactly at the bound be kept unchanged?
                        elif g_tmp.has_edge(nd1, nd2):
                            g_tmp.remove_edge(nd1, nd2)

        # @todo: should we update all graphs generated or just the best ones?
        dis_list, pi_forward_list = ged_median(G_new_list,
                                               Gn_median,
                                               params_ged=params_ged)
        # @todo: should we remove the identical and connectivity check?
        # Don't know which is faster.
        if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
            G_new_list, idx_list = remove_duplicates(G_new_list)
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
            dis_list = [dis_list[idx] for idx in idx_list]

        return G_new_list, pi_forward_list, dis_list

    def best_median_graphs(Gn_candidate, pi_all_forward, dis_all):
        """Keep only the candidates that attain the smallest distance.

        Parameters
        ----------
        Gn_candidate : list
            Candidate graphs.
        pi_all_forward : list
            Forward node mappings, aligned index-by-index with Gn_candidate.
        dis_all : list
            Distances (SODs), aligned with Gn_candidate; must be non-empty.

        Returns
        -------
        G_min_list : list
            All candidates tied at the minimum distance, in input order.
        pi_forward_min_list : list
            Their mappings, in the same order.
        dis_min
            The minimum distance, taken from ``dis_all`` itself.
        """
        dists = np.asarray(dis_all)
        best_positions = np.flatnonzero(dists == dists.min()).tolist()
        G_min_list = [Gn_candidate[pos] for pos in best_positions]
        pi_forward_min_list = [pi_all_forward[pos] for pos in best_positions]
        return G_min_list, pi_forward_min_list, dis_all[best_positions[0]]

    def iteration_proc(G, pi_p_forward, cur_sod):
        """Iteratively refine one starting graph with ``generate_graph``.

        Every graph of the current generation is passed to ``generate_graph``
        and all of its outputs form the next generation. Iteration stops
        after ``ite_max`` rounds, or once the minimum SOD changes by at most
        ``epsilon`` between consecutive rounds.

        Parameters
        ----------
        G : graph
            Starting graph.
        pi_p_forward : list
            Forward node mapping associated with ``G``.
        cur_sod : number
            SOD of ``G`` to the median set.

        Returns
        -------
        G_list : list
            Graphs of the last generation that attain the minimum SOD.
            Duplicates are removed when the dataset has no node/edge attributes.
        pi_forward_list : list
            Mappings aligned with ``G_list``.
        dis_min
            The minimum SOD of the last generation.
        sod_list : list
            Minimum SOD after each round, starting with the initial ``cur_sod``.
        """
        graphs = [G]
        mappings = [pi_p_forward]
        dists = [cur_sod]
        sod_history = [cur_sod]
        # Seed the previous SOD far enough away that the first round runs
        # (unless cur_sod is already within epsilon of 0).
        prev_sod = cur_sod * 2
        # TODO: what if the difference is exactly 0 in a later round?
        rounds = 0
        while rounds < ite_max and np.abs(prev_sod - cur_sod) > epsilon:
            print('itr_iam is', rounds)
            next_graphs, next_mappings, next_dists = [], [], []
            for pos, graph in enumerate(graphs):
                produced_g, produced_pi, produced_d = generate_graph(
                    graph, mappings[pos])
                next_graphs.extend(produced_g)
                next_mappings.extend(produced_pi)
                next_dists.extend(produced_d)
            # TODO: need to remove duplicates here?
            graphs = [gr.copy() for gr in next_graphs]
            mappings = [mp.copy() for mp in next_mappings]
            dists = list(next_dists)

            prev_sod, cur_sod = cur_sod, np.min(dists)
            sod_history.append(cur_sod)
            rounds += 1

        # TODO: return all graphs or only the best ones? Currently the best.
        graphs, mappings, dis_min = best_median_graphs(graphs, mappings, dists)

        # Duplicate detection is only applied to attribute-free datasets.
        if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
            graphs, kept = remove_duplicates(graphs)
            mappings = [mappings[pos] for pos in kept]

        print('\nsods:', sod_history, '\n')

        return graphs, mappings, dis_min, sod_history

    def remove_duplicates(Gn):
        """Remove duplicate graphs from a list, keeping first occurrences.

        Two graphs are duplicates when ``graph_isIdentical`` says so.

        Returns
        -------
        Gn_new : list
            The retained graphs, in input order.
        idx_list : list of int
            Positions in ``Gn`` of the retained graphs, so callers can filter
            parallel lists (mappings, distances) the same way.
        """
        kept_graphs = []
        kept_positions = []
        for pos, candidate in enumerate(Gn):
            if any(graph_isIdentical(seen, candidate) for seen in kept_graphs):
                continue
            kept_graphs.append(candidate)
            kept_positions.append(pos)
        return kept_graphs, kept_positions

    def remove_disconnected(Gn):
        """Remove disconnected graphs from a list.

        Returns the connected graphs (input order preserved) together with
        their positions in ``Gn``, so parallel lists can be filtered the
        same way.
        """
        connected = [(pos, g) for pos, g in enumerate(Gn) if nx.is_connected(g)]
        return [g for _, g in connected], [pos for pos, _ in connected]

    ###########################################################################

    # phase 1: initilize.
    # compute set-median.
    dis_min = np.inf
    dis_list, pi_forward_all = ged_median(Gn_candidate,
                                          Gn_median,
                                          params_ged=params_ged,
                                          parallel=True)
    print('finish computing GEDs.')
    # find all smallest distances.
    if allBestInit:  # try all best init graphs.
        idx_min_list = range(len(dis_list))
        dis_min = dis_list
    else:
        idx_min_list = np.argwhere(
            dis_list == np.min(dis_list)).flatten().tolist()
        dis_min = [dis_list[idx_min_list[0]]] * len(idx_min_list)
        idx_min_rdm = random.randint(0, len(idx_min_list) - 1)
        idx_min_list = [idx_min_list[idx_min_rdm]]
    sod_set_median = np.min(dis_min)

    # phase 2: iteration.
    G_list = []
    dis_list = []
    pi_forward_list = []
    G_set_median_list = []
    #    sod_list = []
    for idx_tmp, idx_min in enumerate(idx_min_list):
        #        print('idx_min is', idx_min)
        G = Gn_candidate[idx_min].copy()
        G_set_median_list.append(G.copy())
        # list of edit operations.
        pi_p_forward = pi_forward_all[idx_min]
        #        pi_p_backward = pi_all_backward[idx_min]
        Gi_list, pi_i_forward_list, dis_i_min, sod_list = iteration_proc(
            G, pi_p_forward, dis_min[idx_tmp])
        G_list += Gi_list
        dis_list += [dis_i_min] * len(Gi_list)
        pi_forward_list += pi_i_forward_list

    if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
        G_list, idx_list = remove_duplicates(G_list)
        dis_list = [dis_list[idx] for idx in idx_list]
        pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
    if connected == True:
        G_list_con, idx_list = remove_disconnected(G_list)
        # if there is no connected graphs at all, then remain the disconnected ones.
        if len(G_list_con) > 0:  # @todo: ??????????????????????????
            G_list = G_list_con
            dis_list = [dis_list[idx] for idx in idx_list]
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]


#    import matplotlib.pyplot as plt
#    for g in G_list:
#        nx.draw_networkx(g)
#        plt.show()
#        print(g.nodes(data=True))
#        print(g.edges(data=True))

# get the best median graphs
    G_gen_median_list, pi_forward_min_list, sod_gen_median = best_median_graphs(
        G_list, pi_forward_list, dis_list)
    #    for g in G_gen_median_list:
    #        nx.draw_networkx(g)
    #        plt.show()
    #        print(g.nodes(data=True))
    #        print(g.edges(data=True))

    if not allBestOutput:
        # randomly choose one graph.
        idx_rdm = random.randint(0, len(G_gen_median_list) - 1)
        G_gen_median_list = [G_gen_median_list[idx_rdm]]

    return G_gen_median_list, sod_gen_median, sod_list, G_set_median_list, sod_set_median
def test_gkiam_2combination():
    """Search a pre-image of a 2-graph kernel-space combination on MUTAG.

    MUTAG graphs 10 and 11 are combined with weights ``alpha`` and
    ``1 - alpha``. ``gk_iam_nearest_multi`` then searches for a pre-image.
    For each alpha the best pre-image is drawn, saved under
    ``results/gk_iam/``, and its SOD (sum of GEDs) to the two input graphs
    is printed.

    Reads the precomputed Gram matrix ``results/gram_matrix.gm.npz``. The
    matrix must cover ``Gn + [g1, g2]``, with the two combined graphs at the
    last two indices.
    """
    from gk_iam import gk_iam_nearest_multi
    ds = {
        'name': 'MUTAG',
        'dataset': '../datasets/MUTAG/MUTAG_A.txt',
        'extra_params': {}
    }  # node/edge symb
    Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    remove_edges(Gn)
    gkernel = 'marginalizedkernel'

    r_max = 10  # iteration limit for pre-image.
    alpha_range = np.linspace(0.5, 0.5, 1)
    k = 20  # k nearest neighbors
    epsilon = 1e-6
    ged_cost = 'CHEM_1'
    ged_method = 'IPFP'
    saveGXL = 'gedlib'
    c_ei = 1
    c_er = 1
    c_es = 1

    # The two graphs to combine. The indices are fixed; the seed call is
    # kept from the former random selection (it sets numpy's global RNG).
    np.random.seed(1)
    idx_gi = [10, 11]  # np.random.randint(0, len(Gn), 2)
    g1 = Gn[idx_gi[0]].copy()
    g2 = Gn[idx_gi[1]].copy()

    # Gram matrix over Gn + [g1, g2], precomputed with:
    #    Gn_mix = [g.copy() for g in Gn] + [g1.copy(), g2.copy()]
    #    time0 = time.time()
    #    km = compute_kernel(Gn_mix, gkernel, True)
    #    time_km = time.time() - time0
    #    np.savez('results/gram_matrix.gm', gm=km, gmtime=time_km)
    gmfile = np.load('results/gram_matrix.gm.npz')
    km = gmfile['gm']
    time_km = gmfile['gmtime']

    time_list = []
    dis_ks_min_list = []
    sod_gs_list = []
    sod_gs_min_list = []
    nb_updated_list = []
    g_best = []
    # for each alpha
    for alpha in alpha_range:
        print('\n-------------------------------------------------------\n')
        print('alpha =', alpha)
        time0 = time.time()
        dhat, ghat_list, sod_ks, nb_updated = gk_iam_nearest_multi(
            Gn, [g1, g2], [alpha, 1 - alpha],
            range(len(Gn),
                  len(Gn) + 2),
            km,
            k,
            r_max,
            gkernel,
            c_ei=c_ei,
            c_er=c_er,
            c_es=c_es,
            epsilon=epsilon,
            ged_cost=ged_cost,
            ged_method=ged_method,
            saveGXL=saveGXL)
        time_total = time.time() - time0 + time_km
        print('time: ', time_total)
        time_list.append(time_total)
        dis_ks_min_list.append(dhat)
        g_best.append(ghat_list)
        nb_updated_list.append(nb_updated)

    # show best graphs and save them to file.
    for idx, item in enumerate(alpha_range):
        print('when alpha is', item, 'the shortest distance is',
              dis_ks_min_list[idx])
        print('one of the possible corresponding pre-images is')
        nx.draw(g_best[idx][0],
                labels=nx.get_node_attributes(g_best[idx][0], 'atom'),
                with_labels=True)
        plt.savefig('results/gk_iam/mutag_alpha' + str(item) + '.png',
                    format="PNG")
        plt.show()
        print(g_best[idx][0].nodes(data=True))
        print(g_best[idx][0].edges(data=True))

    # Compute the corresponding SOD in graph space for the best pre-image of
    # each alpha. (g_best[i] is a *list* of graphs; its first entry is the
    # graph drawn above.)
    # NOTE(review): other callers in this file pass ``params_ged=...`` to
    # ged_median; confirm these keyword arguments are still accepted.
    for ghat_list in g_best:
        sod_tmp, _ = ged_median([ghat_list[0]], [g1, g2],
                                ged_cost=ged_cost,
                                ged_method=ged_method,
                                saveGXL=saveGXL)
        sod_gs_list.append(sod_tmp)
        sod_gs_min_list.append(np.min(sod_tmp))

    print('\nsods in graph space: ', sod_gs_list)
    print('\nsmallest sod in graph space for each alpha: ', sod_gs_min_list)
    print('\nsmallest distance in kernel space for each alpha: ',
          dis_ks_min_list)
    print('\nnumber of updates for each alpha: ', nb_updated_list)
    print('\ntimes:', time_list)
def test_gkiam_2combination_all_pairs():
    """Run ``preimage_iam`` on pairs of MUTAG graphs and log update counts.

    For each selected pair ``(idx1, idx2)``:

    - Both graphs are drawn and saved as ``mutag<idx>.png``.
    - Their kernel values are written into the two trailing rows/columns of
      the precomputed Gram matrix.
    - The pair is combined with weights ``alpha`` and ``1 - alpha``.
    - The best pre-image is saved as a PNG, and its SOD to the pair is
      printed.
    - The number of updates of the best graph is prepended to
      ``results/gk_iam/all_pairs/nb_updates.txt``. The file is created if it
      does not exist.

    Reads ``results/gram_matrix_marg_itr10_pq0.03.gm.npz``. The matrix must
    have at least ``len(Gn) + 2`` rows and columns.
    """
    ds = {
        'name': 'MUTAG',
        'dataset': '../datasets/MUTAG/MUTAG_A.txt',
        'extra_params': {}
    }  # node/edge symb
    Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    remove_edges(Gn)
    gkernel = 'marginalizedkernel'

    r_max = 10  # iteration limit for pre-image.
    alpha_range = np.linspace(0.5, 0.5, 1)
    k = 5  # k nearest neighbors
    epsilon = 1e-6
    InitIAMWithAllDk = False
    # parameters for GED function
    ged_cost = 'CHEM_1'
    ged_method = 'IPFP'
    saveGXL = 'gedlib'
    # parameters for IAM function
    c_ei = 1
    c_er = 1
    c_es = 1
    ite_max_iam = 50
    epsilon_iam = 0.001
    removeNodes = True
    connected_iam = False
    nb_updates_path = 'results/gk_iam/all_pairs/nb_updates.txt'

    # The stored Gram matrix is loop-invariant: load it once, and give each
    # pair its own copy whose two trailing rows/columns are overwritten.
    gmfile = np.load('results/gram_matrix_marg_itr10_pq0.03.gm.npz')
    km_base = gmfile['gm']
    time_km = gmfile['gmtime']
    nb_gn = len(Gn)

    nb_update_mat = np.full((nb_gn, nb_gn), np.inf)
    # test on each pair of graphs. To cover all pairs, iterate instead:
    #    for idx1 in range(len(Gn) - 1, -1, -1):
    #        for idx2 in range(idx1, -1, -1):
    for idx1 in range(187, 188):
        for idx2 in range(167, 168):
            g1 = Gn[idx1].copy()
            g2 = Gn[idx2].copy()

            # draw and save both graphs; file names follow the graph indices.
            for idx_g, g in ((idx1, g1), (idx2, g2)):
                nx.draw(g,
                        labels=nx.get_node_attributes(g, 'atom'),
                        with_labels=True)
                plt.savefig('results/gk_iam/all_pairs/mutag' + str(idx_g) +
                            '.png', format="PNG")
                plt.show()
                plt.clf()

            # Mix the pair into the Gram matrix: slot nb_gn takes the kernel
            # values of g1 and slot nb_gn + 1 those of g2.
            km = km_base.copy()
            pair = [idx1, idx2]
            km[:nb_gn, nb_gn:nb_gn + 2] = km[:nb_gn, pair]
            km[nb_gn:nb_gn + 2, :nb_gn] = km[:nb_gn, pair].T
            km[nb_gn:nb_gn + 2, nb_gn:nb_gn + 2] = km[np.ix_(pair, pair)]

            time_list = []
            dis_ks_min_list = []
            sod_gs_list = []
            sod_gs_min_list = []
            nb_updated_list = []
            nb_updated_k_list = []
            g_best = []
            # for each alpha
            for alpha in alpha_range:
                print(
                    '\n-------------------------------------------------------\n'
                )
                print('alpha =', alpha)
                time0 = time.time()
                dhat, ghat_list, sod_ks, nb_updated, nb_updated_k = \
                    preimage_iam(Gn, [g1, g2],
                    [alpha, 1 - alpha], range(nb_gn, nb_gn + 2), km, k, r_max,
                    gkernel, epsilon=epsilon, InitIAMWithAllDk=InitIAMWithAllDk,
                    params_iam={'c_ei': c_ei, 'c_er': c_er, 'c_es': c_es,
                                'ite_max': ite_max_iam, 'epsilon': epsilon_iam,
                                'removeNodes': removeNodes, 'connected': connected_iam},
                    params_ged={'ged_cost': ged_cost, 'ged_method': ged_method,
                                'saveGXL': saveGXL})
                time_total = time.time() - time0 + time_km
                print('time: ', time_total)
                time_list.append(time_total)
                dis_ks_min_list.append(dhat)
                g_best.append(ghat_list)
                nb_updated_list.append(nb_updated)
                nb_updated_k_list.append(nb_updated_k)

            # show best graphs and save them to file.
            for idx, item in enumerate(alpha_range):
                print('when alpha is', item, 'the shortest distance is',
                      dis_ks_min_list[idx])
                print('one of the possible corresponding pre-images is')
                nx.draw(g_best[idx][0],
                        labels=nx.get_node_attributes(g_best[idx][0], 'atom'),
                        with_labels=True)
                plt.savefig('results/gk_iam/mutag' + str(idx1) + '_' +
                            str(idx2) + '_alpha' + str(item) + '.png',
                            format="PNG")
                plt.clf()

            # Compute the corresponding SOD in graph space for the best
            # pre-image of each alpha (g_best[i] is a list of graphs).
            # NOTE(review): other callers in this file pass
            # ``params_ged=...`` to ged_median; confirm these keyword
            # arguments are still accepted.
            for ghat_list in g_best:
                sod_tmp, _ = ged_median([ghat_list[0]], [g1, g2],
                                        ged_cost=ged_cost,
                                        ged_method=ged_method,
                                        saveGXL=saveGXL)
                sod_gs_list.append(sod_tmp)
                sod_gs_min_list.append(np.min(sod_tmp))

            print('\nsods in graph space: ', sod_gs_list)
            print('\nsmallest sod in graph space for each alpha: ',
                  sod_gs_min_list)
            print('\nsmallest distance in kernel space for each alpha: ',
                  dis_ks_min_list)
            print('\nnumber of updates of the best graph for each alpha: ',
                  nb_updated_list)
            print(
                '\nnumber of updates of the k nearest graphs for each alpha: ',
                nb_updated_k_list)
            print('\ntimes:', time_list)
            nb_update_mat[idx1, idx2] = nb_updated_list[0]

            # Prepend this pair's result to the log. A missing log file is
            # treated as empty instead of raising FileNotFoundError.
            str_fw = 'graphs %d and %d: %d.\n' % (idx1, idx2,
                                                  nb_updated_list[0])
            try:
                with open(nb_updates_path, 'r') as file:
                    content = file.read()
            except FileNotFoundError:
                content = ''
            with open(nb_updates_path, 'w') as file:
                file.write(str_fw + content)
def test_preimage_iam_median_nb():
    """Compute IAM pre-images for equally weighted sets of MUTAG graphs.

    For each ``nb_median`` in ``nb_median_range``:

    - ``nb_median`` graphs labelled 1 are sampled with a fixed seed.
    - They are appended in kernel space to the precomputed Gram matrix.
    - ``preimage_iam`` searches for a pre-image of their combination, using
      equal weights ``1 / nb_median``.
    - The best pre-image is drawn, and its SOD to the sampled graphs is
      printed.

    Reads ``results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm.npz``,
    a Gram matrix over the label-1 graphs.
    """
    ds = {
        'name': 'MUTAG',
        'dataset': '../datasets/MUTAG/MUTAG_A.txt',
        'extra_params': {}
    }  # node/edge symb
    Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    remove_edges(Gn)
    gkernel = 'marginalizedkernel'

    r_max = 3  # iteration limit for pre-image.
    k = 5  # k nearest neighbors
    epsilon = 1e-6
    InitIAMWithAllDk = True
    # parameters for IAM function (an alternative cost set that was tried:
    # c_vi=0.037, c_vr=0.038, c_vs=0.075, c_ei=0.001, c_er=0.001, c_es=0.0)
    c_vi = 4
    c_vr = 4
    c_vs = 2
    c_ei = 1
    c_er = 1
    c_es = 1
    ite_max_iam = 50
    epsilon_iam = 0.001
    removeNodes = True
    connected_iam = False
    # parameters for GED function
    #    ged_cost='CHEM_1'
    ged_cost = 'CONSTANT'
    ged_method = 'IPFP'
    edit_cost_constant = [c_vi, c_vr, c_vs, c_ei, c_er, c_es]
    ged_stabilizer = 'min'
    ged_repeat = 50
    params_ged = {
        'lib': 'gedlibpy',
        'cost': ged_cost,
        'method': ged_method,
        'edit_cost_constant': edit_cost_constant,
        'stabilizer': ged_stabilizer,
        'repeat': ged_repeat
    }

    # number of graphs whose median is computed.
    #    nb_median_range = [2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
    nb_median_range = [2]

    # find out all the graphs classified to positive group 1.
    idx_dict = get_same_item_indices(y_all)
    Gn = [Gn[i] for i in idx_dict[1]]
    nb_gn = len(Gn)

    # Gram matrix over the label-1 graphs, precomputed with:
    #    time0 = time.time()
    #    km = compute_kernel(Gn, gkernel, True)
    #    time_km = time.time() - time0
    #    np.savez('results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm', gm=km, gmtime=time_km)
    # It does not depend on nb_median, so it is loaded once before the loop.
    gmfile = np.load(
        'results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm.npz')
    km_tmp = gmfile['gm']
    time_km = gmfile['gmtime']

    time_list = []
    dis_ks_min_list = []
    sod_gs_list = []
    sod_gs_min_list = []
    nb_updated_list = []
    nb_updated_k_list = []
    g_best = []
    for nb_median in nb_median_range:
        print('\n-------------------------------------------------------')
        print('number of median graphs =', nb_median)
        random.seed(1)
        idx_rdm = random.sample(range(nb_gn), nb_median)
        print('graphs chosen:', idx_rdm)
        Gn_median = [Gn[idx].copy() for idx in idx_rdm]

        # Build the mixed Gram matrix of size nb_gn + nb_median:
        # - the leading nb_gn x nb_gn block is the upper triangle of the
        #   stored matrix, mirrored so the block is symmetric;
        # - the trailing nb_median rows/columns duplicate those of the
        #   sampled graphs.
        km = np.zeros((nb_gn + nb_median, nb_gn + nb_median))
        upper = np.triu(km_tmp[:nb_gn, :nb_gn])
        km[:nb_gn, :nb_gn] = upper + np.triu(upper, 1).T
        km[:nb_gn, nb_gn:] = km[:nb_gn, idx_rdm]
        km[nb_gn:, :nb_gn] = km[:nb_gn, idx_rdm].T
        km[nb_gn:, nb_gn:] = km[np.ix_(idx_rdm, idx_rdm)]

        alpha_range = [1 / nb_median] * nb_median
        time0 = time.time()
        dhat, ghat_list, dis_of_each_itr, nb_updated, nb_updated_k = \
            preimage_iam(Gn, Gn_median,
            alpha_range, range(nb_gn, nb_gn + nb_median), km, k, r_max,
            gkernel, epsilon=epsilon, InitIAMWithAllDk=InitIAMWithAllDk,
            params_iam={'c_ei': c_ei, 'c_er': c_er, 'c_es': c_es,
                        'ite_max': ite_max_iam, 'epsilon': epsilon_iam,
                        'removeNodes': removeNodes, 'connected': connected_iam},
            params_ged=params_ged)

        time_total = time.time() - time0 + time_km
        print('\ntime: ', time_total)
        time_list.append(time_total)
        print('\nsmallest distance in kernel space: ', dhat)
        dis_ks_min_list.append(dhat)
        g_best.append(ghat_list)
        print('\nnumber of updates of the best graph: ', nb_updated)
        nb_updated_list.append(nb_updated)
        print('\nnumber of updates of k nearest graphs: ', nb_updated_k)
        nb_updated_k_list.append(nb_updated_k)

        # show the best graph.
        print('the shortest distance is', dhat)
        print('one of the possible corresponding pre-images is')
        nx.draw(ghat_list[0],
                labels=nx.get_node_attributes(ghat_list[0], 'atom'),
                with_labels=True)
        plt.show()
        #        plt.savefig('results/preimage_iam/mutag_median_cs.001_nb' + str(nb_median) +
        #                    '.png', format="PNG")
        plt.clf()

        # compute the corresponding sod in graph space.
        sod_tmp, _ = ged_median([ghat_list[0]],
                                Gn_median,
                                params_ged=params_ged)
        sod_gs_list.append(sod_tmp)
        sod_gs_min_list.append(np.min(sod_tmp))
        print('\nsmallest sod in graph space: ', np.min(sod_tmp))

    print('\nsods in graph space: ', sod_gs_list)
    print('\nsmallest sod in graph space for each set of median graphs: ',
          sod_gs_min_list)
    print(
        '\nsmallest distance in kernel space for each set of median graphs: ',
        dis_ks_min_list)
    print(
        '\nnumber of updates of the best graph for each set of median graphs by IAM: ',
        nb_updated_list)
    print(
        '\nnumber of updates of k nearest graphs for each set of median graphs by IAM: ',
        nb_updated_k_list)
    print('\ntimes:', time_list)