Beispiel #1
0
def motif_stats(graph_ref_list,
                graph_pred_list,
                motif_type='4cycle',
                ground_truth_match=None,
                bins=100):
    """Compare size-normalised motif counts of two graph sets with MMD.

    For each graph, ``orca`` is run, the orbit columns selected by
    ``motif_to_indices[motif_type]`` are summed per row, and the total is
    divided by the graph's node count.  The resulting scalars of both sets are
    passed to ``compute_mmd`` with the ``gaussian`` kernel and
    ``is_hist=False``.

    Args:
      graph_ref_list: reference graphs (used as-is, not filtered).
      graph_pred_list: generated graphs; graphs with zero nodes are dropped.
      motif_type: key into ``motif_to_indices`` selecting the orbit columns.
      ground_truth_match: when not None, the fraction of rows whose motif
        count equals this value is computed per graph.  That fraction is not
        part of the return value.
      bins: accepted for interface compatibility; not used by this function.

    Returns:
      The value returned by ``compute_mmd``.
    """
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]
    orbit_columns = motif_to_indices[motif_type]

    def summarize(graphs):
        # Returns (normalised motif totals, match fractions) for `graphs`.
        densities = []
        match_fractions = []
        for graph in graphs:
            per_row = np.sum(orca(graph)[:, orbit_columns], axis=1)

            if ground_truth_match is not None:
                hits = sum(1 for count in per_row
                           if count == ground_truth_match)
                match_fractions.append(hits / graph.number_of_nodes())

            densities.append(np.sum(per_row) / graph.number_of_nodes())
        return densities, match_fractions

    # The match fractions are computed but intentionally not returned,
    # mirroring the established behaviour of this function.
    densities_ref, _ = summarize(graph_ref_list)
    densities_pred, _ = summarize(non_empty_pred)

    return compute_mmd(densities_ref,
                       densities_pred,
                       kernel=gaussian,
                       is_hist=False)
Beispiel #2
0
def orbit_stats_all(graph_ref_list, graph_pred_list):
    """Compare per-graph orbit-count profiles of two graph sets with MMD.

    For every graph, ``orca`` is run and its result is summed over axis 0 and
    divided by the graph's node count.  The two collections of profiles are
    compared via ``compute_mmd`` using the ``gaussian_tv`` kernel with
    ``is_hist=False`` and ``sigma=30.0``.

    Graphs for which ``orca`` raises an ``Exception`` are logged through
    ``get_logger()`` and skipped.  Generated graphs with zero nodes are
    dropped before ``orca`` is called, so they can neither trigger a failed
    ``orca`` run nor a division by a zero node count.

    Args:
      graph_ref_list: reference graphs.
      graph_pred_list: generated graphs; may contain empty graphs.

    Returns:
      The value returned by ``compute_mmd``.
    """
    graph_pred_list_remove_empty = [
        G for G in graph_pred_list if G.number_of_nodes() != 0
    ]

    def collect_profiles(graphs, failure_message):
        # One node-normalised orbit profile per graph that orca can process.
        profiles = []
        for G in graphs:
            try:
                orbit_counts = orca(G)
            except Exception:
                # `Exception` (not a bare except) so that KeyboardInterrupt
                # and SystemExit still propagate.
                logger = get_logger()
                logger.error(failure_message)
                logger.error(traceback.format_exc())
                continue
            profiles.append(
                np.sum(orbit_counts, axis=0) / G.number_of_nodes())
        return profiles

    total_counts_ref = np.array(
        collect_profiles(
            graph_ref_list,
            'in orbit_stats_all, unable to run orca(G), graph ref list'))
    # Iterate the filtered list: the filter above was previously computed but
    # never used.
    total_counts_pred = np.array(
        collect_profiles(
            graph_pred_list_remove_empty,
            'in orbit_stats_all, unable to run orca(G), graph pred list'))

    mmd_dist = compute_mmd(total_counts_ref,
                           total_counts_pred,
                           kernel=gaussian_tv,
                           is_hist=False,
                           sigma=30.0)
    return mmd_dist
Beispiel #3
0
def clustering_stats(graph_ref_list,
                     graph_pred_list,
                     bins=100,
                     is_parallel=True):
    """Compare clustering-coefficient histograms of two graph sets with MMD.

    One histogram per graph is produced, either by ``clustering_worker`` on a
    thread pool (``is_parallel=True``) or inline from ``nx.clustering`` with
    ``bins`` bins over the range [0, 1].  The histograms are compared via
    ``compute_mmd`` with the ``gaussian_tv`` kernel and ``sigma=1.0 / 10``.

    Args:
      graph_ref_list: reference graphs.
      graph_pred_list: generated graphs; graphs with zero nodes are dropped.
      bins: number of histogram bins.
      is_parallel: whether to compute the histograms on a thread pool.

    Returns:
      The value returned by ``compute_mmd``.
    """
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]

    started = datetime.now()
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor() as pool:
            sample_ref = list(
                pool.map(clustering_worker,
                         [(G, bins) for G in graph_ref_list]))
        with concurrent.futures.ThreadPoolExecutor() as pool:
            sample_pred = list(
                pool.map(clustering_worker,
                         [(G, bins) for G in non_empty_pred]))
    else:

        def coefficient_histogram(graph):
            # Un-normalised histogram of the graph's clustering coefficients.
            coefficients = list(nx.clustering(graph).values())
            counts, _ = np.histogram(coefficients,
                                     bins=bins,
                                     range=(0.0, 1.0),
                                     density=False)
            return counts

        sample_ref = [coefficient_histogram(G) for G in graph_ref_list]
        sample_pred = [coefficient_histogram(G) for G in non_empty_pred]

    mmd_dist = compute_mmd(sample_ref,
                           sample_pred,
                           kernel=gaussian_tv,
                           sigma=1.0 / 10)

    # The timing covers histogram construction and the MMD computation.
    elapsed = datetime.now() - started
    if PRINT_TIME:
        print('Time computing clustering mmd: ', elapsed)
    return mmd_dist
Beispiel #4
0
def degree_stats(graph_ref_list, graph_pred_list, is_parallel=True):
    ''' Compute the MMD between the degree histograms of two sets of graphs.
    Args:
      graph_ref_list: list of reference networkx graphs
      graph_pred_list: list of generated networkx graphs; graphs with zero
        nodes are dropped before any histogram is computed
      is_parallel: if True, histograms are computed by `degree_worker` on a
        thread pool; otherwise inline via `nx.degree_histogram`
    Returns:
      The value returned by `compute_mmd` with the `gaussian` kernel.
    '''
    # in case an empty graph is generated
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]

    started = datetime.now()
    if is_parallel:
        with concurrent.futures.ThreadPoolExecutor() as pool:
            sample_ref = list(pool.map(degree_worker, graph_ref_list))
        with concurrent.futures.ThreadPoolExecutor() as pool:
            sample_pred = list(pool.map(degree_worker, non_empty_pred))
    else:
        sample_ref = [
            np.array(nx.degree_histogram(G)) for G in graph_ref_list
        ]
        sample_pred = [
            np.array(nx.degree_histogram(G)) for G in non_empty_pred
        ]

    mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=gaussian)

    # The timing covers histogram construction and the MMD computation.
    elapsed = datetime.now() - started
    if PRINT_TIME:
        print('Time computing degree mmd: ', elapsed)
    return mmd_dist
Beispiel #5
0
def orbit_stats_all(graph_ref_list, graph_pred_list):
    """Compare per-graph orbit-count profiles of two graph sets with MMD.

    For every graph, ``orca`` is run and its result is summed over axis 0 and
    divided by the graph's node count.  The two collections of profiles are
    compared via ``compute_mmd`` using the ``gaussian`` kernel with
    ``is_hist=False`` and ``sigma=30.0``.

    Graphs for which ``orca`` raises an ``Exception`` are skipped with a
    logged warning.  Generated graphs with zero nodes are dropped before
    ``orca`` is called.

    Args:
      graph_ref_list: reference graphs.
      graph_pred_list: generated graphs; may contain empty graphs.

    Returns:
      The value returned by ``compute_mmd``.
    """
    # Local import: only part of the file is visible, so no module-level
    # logger is assumed to exist.
    import logging
    log = logging.getLogger(__name__)

    graph_pred_list_remove_empty = [
        G for G in graph_pred_list if G.number_of_nodes() != 0
    ]

    def collect_profiles(graphs, set_name):
        # One node-normalised orbit profile per graph that orca can process.
        profiles = []
        for G in graphs:
            try:
                orbit_counts = orca(G)
            except Exception:
                # Best-effort: keep going, but leave a trace instead of
                # dropping the graph silently.  KeyboardInterrupt and
                # SystemExit are not caught and still propagate.
                log.warning('orbit_stats_all: orca failed on a %s graph; '
                            'skipping it', set_name, exc_info=True)
                continue
            profiles.append(
                np.sum(orbit_counts, axis=0) / G.number_of_nodes())
        return profiles

    total_counts_ref = np.array(collect_profiles(graph_ref_list, 'reference'))
    # Iterate the filtered list: the filter above was previously computed but
    # never used.
    total_counts_pred = np.array(
        collect_profiles(graph_pred_list_remove_empty, 'predicted'))

    mmd_dist = compute_mmd(total_counts_ref,
                           total_counts_pred,
                           kernel=gaussian,
                           is_hist=False,
                           sigma=30.0)
    return mmd_dist
Beispiel #6
0
def max_degree_stats(graph_ref_list, graph_pred_list):
    """Compare the maximum-degree distributions of two graph sets with MMD.

    The largest node degree of every graph is collected, each set's values
    are turned into a single ``np.bincount`` histogram, and the two
    histograms are compared via ``compute_mmd`` with the ``gaussian`` kernel.

    Args:
      graph_ref_list: reference graphs.
      graph_pred_list: generated graphs; graphs with zero nodes are dropped.

    Returns:
      The value returned by ``compute_mmd``.
    """
    # in case an empty graph is generated
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]

    def largest_degree(graph):
        degree_by_node = dict(nx.degree(graph))
        return np.max(list(degree_by_node.values()))

    max_degrees_ref = [largest_degree(G) for G in graph_ref_list]
    max_degrees_pred = [largest_degree(G) for G in non_empty_pred]

    hist_ref = np.bincount(max_degrees_ref)
    hist_pred = np.bincount(max_degrees_pred)

    return compute_mmd([hist_ref], [hist_pred], kernel=gaussian)
Beispiel #7
0
def assortativity_stats(graph_ref_list, graph_pred_list):
    """Compare degree-assortativity distributions of two graph sets with MMD.

    ``nx.degree_assortativity_coefficient`` is evaluated for every graph.
    Both sets are binned into 100-bin histograms over their common
    [min, max] range and compared via ``compute_mmd`` with the ``gaussian``
    kernel.

    Args:
      graph_ref_list: reference graphs.
      graph_pred_list: generated graphs; graphs with zero nodes are dropped.

    Returns:
      The value returned by ``compute_mmd``.
    """
    # in case an empty graph is generated
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]

    coeffs_ref = []
    for G in graph_ref_list:
        coeffs_ref.append(nx.degree_assortativity_coefficient(G))
    coeffs_pred = []
    for G in non_empty_pred:
        coeffs_pred.append(nx.degree_assortativity_coefficient(G))

    # Shared bin edges so that both histograms are directly comparable.
    lower = min(min(coeffs_ref), min(coeffs_pred))
    upper = max(max(coeffs_ref), max(coeffs_pred))
    shared_range = (lower, upper)

    hist_ref, _ = np.histogram(coeffs_ref,
                               bins=100,
                               range=shared_range,
                               density=False)
    hist_pred, _ = np.histogram(coeffs_pred,
                                bins=100,
                                range=shared_range,
                                density=False)

    return compute_mmd([hist_ref], [hist_pred], kernel=gaussian)
Beispiel #8
0
def mean_degree_connectivity_stats(graph_ref_list, graph_pred_list):
    """Compare mean average-degree-connectivity of two graph sets with MMD.

    For every graph the values of ``nx.average_degree_connectivity`` are
    averaged into one scalar.  Both sets of scalars are binned into 100-bin
    histograms over their common [min, max] range and compared via
    ``compute_mmd`` with the ``gaussian`` kernel.

    Args:
      graph_ref_list: reference graphs.
      graph_pred_list: generated graphs; graphs with zero nodes are dropped.

    Returns:
      The value returned by ``compute_mmd``.
    """
    # in case an empty graph is generated
    non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]

    def mean_connectivity(graph):
        connectivity_by_degree = nx.average_degree_connectivity(graph)
        return np.mean(list(connectivity_by_degree.values()))

    means_ref = [mean_connectivity(G) for G in graph_ref_list]
    means_pred = [mean_connectivity(G) for G in non_empty_pred]

    # Shared bin edges so that both histograms are directly comparable.
    lower = min(min(means_ref), min(means_pred))
    upper = max(max(means_ref), max(means_pred))
    shared_range = (lower, upper)

    hist_ref, _ = np.histogram(means_ref,
                               bins=100,
                               range=shared_range,
                               density=False)
    hist_pred, _ = np.histogram(means_pred,
                                bins=100,
                                range=shared_range,
                                density=False)

    return compute_mmd([hist_ref], [hist_pred], kernel=gaussian)
Beispiel #9
0
    def test(self):
        """Generate graphs, optionally visualize them, and score them with MMD.

        Graphs come either from an Erdos-Renyi baseline (when
        ``self.config.test.is_test_ER`` is set) or from the checkpoint
        ``self.test_conf.test_model_name`` found in
        ``self.test_conf.test_model_dir``.  When ``self.is_vis`` is set, the
        first ``self.num_vis`` generated graphs are pickled and plotted, and
        the first ``self.num_vis`` training graphs are plotted as well.

        The generated graphs are then compared with ``self.graphs_dev`` and
        ``self.graphs_test`` through ``evaluate`` and the results are logged.

        Returns:
          The four values unpacked from ``evaluate`` (named degree,
          clustering, 4orbits, spectral) for the validation set, followed by
          the same four for the test set.  When the dataset name is
          'lobster', the result of ``eval_acc_lobster_graph`` is appended as
          a ninth element.
        """
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        if self.config.test.is_test_ER:
            # Edge probability: total edge count of the training graphs
            # divided by the total of their squared node counts.
            p_ER = sum([
                aa.number_of_edges() for aa in self.graphs_train
            ]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
            # The loop index doubles as the random seed of each graph.
            graphs_gen = [
                nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii)
                for ii in range(self.num_test_gen)
            ]
        else:
            ### load model
            # NOTE(review): eval() on a configuration string -- make sure the
            # config comes from a trusted source.
            model = eval(self.model_conf.name)(self.config)
            model_file = os.path.join(self.config.save_dir,
                                      self.test_conf.test_model_name)
            load_model(model, model_file, self.device)

            if self.use_gpu:
                model = nn.DataParallel(model,
                                        device_ids=self.gpus).to(self.device)

            model.eval()

            ### Generate Graphs
            A_pred = []
            # num_nodes_pred is filled below but not read again in this method.
            num_nodes_pred = []
            alpha_list = []
            # Enough batches to cover num_test_gen samples (rounded up).
            num_test_batch = int(
                np.ceil(self.num_test_gen / self.test_conf.batch_size))

            gen_run_time = []
            for ii in tqdm(range(num_test_batch)):
                with torch.no_grad():
                    start_time = time.time()
                    input_dict = {}
                    input_dict['is_sampling'] = True
                    input_dict['batch_size'] = self.test_conf.batch_size
                    input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
                    A_tmp, alpha_temp = model(input_dict)
                    # Only the model call is timed; the conversions below are
                    # excluded.
                    gen_run_time += [time.time() - start_time]
                    A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
                    num_nodes_pred += [aa.shape[0] for aa in A_tmp]
                    alpha_list += [aa.data.cpu().numpy() for aa in alpha_temp]

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))

            # Pair every sampled adjacency with the alpha entry at the same
            # position.
            graphs_gen = [
                get_graph(aa, alpha_list[i]) for i, aa in enumerate(A_pred)
            ]

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            # Keep the text between the last '_' and the first '.pth' of the
            # checkpoint file name.
            test_epoch = test_epoch[test_epoch.rfind('_') +
                                    1:test_epoch.find('.pth')]
            # [:-4] drops the last four characters of the checkpoint name
            # (presumably the '.pth' extension).
            save_name = os.path.join(
                self.config.save_dir,
                '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                    self.config.test.test_model_name[:-4], test_epoch,
                    self.block_size, self.stride))

            # Work on deep copies so that the node removal below does not
            # affect the graphs used for evaluation.
            graphs_pred_vis = [
                copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
            ]

            # Save the graphs; this happens before isolated nodes are removed.
            for i, gg in enumerate(graphs_pred_vis):
                G = gg
                name = os.path.join(
                    self.config.save_dir,
                    '{}_gen_graphs_epoch_{}_{}.pickle'.format(
                        self.config.test.test_model_name[:-4], test_epoch, i))
                with open(name, 'wb') as handle:
                    pickle.dump(G, handle)

            # remove isolated nodes for better visualization
            if self.better_vis:
                for gg in graphs_pred_vis:
                    gg.remove_nodes_from(list(nx.isolates(gg)))

            # display the largest connected component for better visualization
            vis_graphs = []
            for gg in graphs_pred_vis:
                CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
                CGs = sorted(CGs,
                             key=lambda x: x.number_of_nodes(),
                             reverse=True)
                vis_graphs += [CGs[0]]

            if self.is_single_plot:
                draw_graph_list(vis_graphs,
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(vis_graphs,
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

            # Plot the same number of training graphs for comparison.
            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')

            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis],
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(self.graphs_train[:self.num_vis],
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

        ### Evaluation
        # `acc` is only defined for the 'lobster' dataset; the return
        # statement at the end is guarded by the same condition.
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info(
                'Validity accuracy of generated graphs = {}'.format(acc))

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # Compared with Validation Set
        num_nodes_dev = [len(gg.nodes)
                         for gg in self.graphs_dev]  # shape B X 1
        mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
            self.graphs_dev, graphs_gen, degree_only=False)
        # Node-count MMD is computed on bincount histograms of graph sizes.
        mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                        [np.bincount(num_nodes_gen)],
                                        kernel=gaussian_emd)

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes)
                          for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                         [np.bincount(num_nodes_gen)],
                                         kernel=gaussian_emd)

        logger.info(
            "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev,
                    mmd_4orbits_dev, mmd_spectral_dev))
        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test,
                    mmd_4orbits_test, mmd_spectral_test))

        # The node-count MMDs are logged above but are not part of the
        # returned tuple.
        if self.config.dataset.name in ['lobster']:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, acc
        else:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test
Beispiel #10
0
def clustering_stats(graph_ref_list,
                     graph_pred_list,
                     bins=100,
                     is_parallel=True):
  """Compare clustering-coefficient histograms of two graph sets with MMD.

  One histogram per graph is produced, either by `clustering_worker` on a
  thread pool (`is_parallel=True`) or inline from `nx.clustering` with `bins`
  bins over the range [0, 1].  The histograms are compared via `compute_mmd`
  with the `gaussian_tv` kernel and `sigma=1.0 / 10`.

  Args:
    graph_ref_list: reference graphs.
    graph_pred_list: generated graphs; graphs with zero nodes are dropped.
    bins: number of histogram bins.
    is_parallel: whether to compute the histograms on a thread pool.

  Returns:
    The value returned by `compute_mmd`.
  """
  non_empty_pred = [G for G in graph_pred_list if G.number_of_nodes() != 0]

  started = datetime.now()
  if is_parallel:
    with concurrent.futures.ThreadPoolExecutor() as pool:
      sample_ref = list(
          pool.map(clustering_worker, [(G, bins) for G in graph_ref_list]))
    with concurrent.futures.ThreadPoolExecutor() as pool:
      sample_pred = list(
          pool.map(clustering_worker, [(G, bins) for G in non_empty_pred]))
  else:

    def coefficient_histogram(graph):
      # Un-normalised histogram of the graph's clustering coefficients.
      coefficients = list(nx.clustering(graph).values())
      counts, _ = np.histogram(
          coefficients, bins=bins, range=(0.0, 1.0), density=False)
      return counts

    sample_ref = [coefficient_histogram(G) for G in graph_ref_list]
    sample_pred = [coefficient_histogram(G) for G in non_empty_pred]

  mmd_dist = compute_mmd(
      sample_ref,
      sample_pred,
      kernel=gaussian_tv,
      sigma=1.0 / 10)

  # The timing covers histogram construction and the MMD computation.
  elapsed = datetime.now() - started
  if PRINT_TIME:
    print('Time computing clustering mmd: ', elapsed)
  return mmd_dist
Beispiel #11
0
    def test(self):
        """Sample labelled graphs from the trained model and visualize them.

        An Erdos-Renyi baseline is built first and stored in
        ``graphs_baseline``.  The checkpoint
        ``self.test_conf.test_model_name`` is then loaded,
        ``self.num_test_gen`` graphs (rounded up to whole batches) are
        sampled, and, when ``self.is_vis`` is set, the non-bipartite generated
        graphs and the first ``self.num_vis`` training graphs are plotted.

        NOTE(review): the unconditional ``return`` after the visualization
        block makes the whole "Evaluation" section below it unreachable, so
        this method currently always returns ``None``.  This looks like a
        debugging short-circuit -- confirm whether it is intended.
        """
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        # if self.config.test.is_test_ER:
        # Edge probability: total edge count of the training graphs divided by
        # the total of their squared node counts.
        p_ER = sum([aa.number_of_edges()
                    for aa in self.graphs_train]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
        # graphs_baseline = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
        # The loop index doubles as the random seed of each graph.
        graphs_gen = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
        temp = []
        for G in graphs_gen:
            G.remove_nodes_from(list(nx.isolates(G)))
            # NOTE(review): G has already been dereferenced on the line above,
            # so this check appears redundant.
            if G is not None:
                #  take the largest connected component
                # NOTE(review): if every node was isolated, the graph is now
                # empty, CGs is empty and CGs[0] would raise IndexError --
                # confirm this cannot happen for the p_ER values in use.
                CGs = [G.subgraph(c) for c in nx.connected_components(G)]
                CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
                temp.append(CGs[0])
        # graphs_gen = temp
        # graphs_baseline is only referenced from commented-out code further
        # down in this method.
        graphs_baseline = temp

        # else:
        ### load model
        # NOTE(review): eval() on a configuration string -- make sure the
        # config comes from a trusted source.
        model = eval(self.model_conf.name)(self.config)
        model_file = os.path.join(self.config.save_dir, self.test_conf.test_model_name)
        load_model(model, model_file, self.device)

        if self.use_gpu:
            model = nn.DataParallel(model, device_ids=self.gpus).to(self.device)

        model.eval()

        ### Generate Graphs
        A_pred = []
        node_label_pred = []
        # num_nodes_pred is filled below but not read again in this method.
        num_nodes_pred = []
        # Enough batches to cover num_test_gen samples (rounded up).
        num_test_batch = int(np.ceil(self.num_test_gen / self.test_conf.batch_size))

        gen_run_time = []
        for ii in tqdm(range(num_test_batch)):
            with torch.no_grad():
                start_time = time.time()
                input_dict = {}
                input_dict['is_sampling'] = True
                input_dict['batch_size'] = self.test_conf.batch_size
                input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
                A_tmp, node_label_tmp = model(input_dict)
                # Only the model call is timed; the conversions below are
                # excluded.
                gen_run_time += [time.time() - start_time]
                A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
                node_label_pred += [ll.data.cpu().numpy() for ll in node_label_tmp]
                num_nodes_pred += [aa.shape[0] for aa in A_tmp]
        # print(len(A_pred), type(A_pred[0]))

        logger.info('Average test time per mini-batch = {}'.format(np.mean(gen_run_time)))

        # print(A_pred[0].shape,
        #       get_graph(A_pred[0]).number_of_nodes(),
        #       get_graph_with_labels(A_pred[0], node_label_pred[0]).number_of_nodes())
        # print(A_pred[0])
        # return
        # graphs_gen = [get_graph(aa) for aa in A_pred]
        # From here on graphs_gen holds the model samples; the Erdos-Renyi
        # graphs built at the top are no longer reachable through this name.
        graphs_gen = [get_graph_with_labels(aa, ll) for aa, ll in zip(A_pred, node_label_pred)]
        # The returned percentages are not used before the early return; the
        # 'bipartite' graph attribute is read in the visualization block below.
        valid_pctg, bipartite_pctg = calculate_validity(graphs_gen)  # for adding bipartite graph attribute

        # return

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            # Keep the text between the last '_' and the first '.pth' of the
            # checkpoint file name.
            test_epoch = test_epoch[test_epoch.rfind('_') + 1:test_epoch.find('.pth')]
            save_name = os.path.join(self.config.save_dir, '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                self.config.test.test_model_name[:-4], test_epoch, self.block_size, self.stride))

            # remove isolated nodes for better visualization
            # graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]]
            # Only graphs whose 'bipartite' attribute is falsy are plotted;
            # unlike the commented-out variant, the list is not capped at
            # self.num_vis.
            graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen if not gg.graph['bipartite']]
            logger.info('Number of not bipartite graphs: {} / {}'.format(len(graphs_pred_vis), len(graphs_gen)))
            # if self.better_vis:
            #     for gg in graphs_pred_vis:
            #         gg.remove_nodes_from(list(nx.isolates(gg)))

            # # display the largest connected component for better visualization
            # vis_graphs = []
            # for gg in graphs_pred_vis:
            #     CGs = [gg.subgraph(c) for c in nx.connected_components(gg)] # nx.subgraph makes a graph frozen!
            #     CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
            #     vis_graphs += [CGs[0]]
            vis_graphs = graphs_pred_vis

            if self.is_single_plot:
                draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
            else:  #XD: using this for now
                draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')

            # Plot training graphs for comparison.
            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')
            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis], num_row, num_col, fname=save_name, layout='spring')
                print('training single plot saved at:', save_name)
            else:  #XD: using this for now
                graph_list_train = [get_graph_from_nx(G) for G in self.graphs_train[:self.num_vis]]
                draw_graph_list_separate(graph_list_train, fname=save_name[:-4], is_single=True, layout='spring')
                print('training plots saved individually at:', save_name[:-4])
        # NOTE(review): unconditional early exit -- everything below this line
        # is unreachable.
        return

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info('Validity accuracy of generated graphs = {}'.format(acc))
        '''=====XD====='''
        ## graphs_gen = [generate_random_baseline_single(len(aa)) for aa in graphs_gen]  # use this line for random baseline MMD scores. Remember to comment it later!
        # draw_hists(self.graphs_test, graphs_baseline, graphs_gen)
        valid_pctg, bipartite_pctg = calculate_validity(graphs_gen)
        # logger.info('Generated {} graphs, valid percentage = {:.2f}, bipartite percentage = {:.2f}'.format(
        #     len(graphs_gen), valid_pctg, bipartite_pctg))
        # # return
        '''=====XD====='''

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # # Compared with Validation Set
        # num_nodes_dev = [len(gg.nodes) for gg in self.graphs_dev]  # shape B X 1
        # mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_mean_degree_dev, mmd_max_degree_dev, mmd_mean_centrality_dev, mmd_assortativity_dev, mmd_mean_degree_connectivity_dev = evaluate(self.graphs_dev, graphs_gen, degree_only=False)
        # mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)
        # logger.info("Validation MMD scores of #nodes/degree/clustering/4orbits/spectral/... are = {:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}".format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_mean_degree_dev, mmd_max_degree_dev, mmd_mean_centrality_dev, mmd_assortativity_dev, mmd_mean_degree_connectivity_dev))

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes) for gg in self.graphs_test]  # shape B X 1
        # This variant of evaluate() is unpacked into nine values.
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, mmd_mean_degree_test, mmd_max_degree_test, mmd_mean_centrality_test, mmd_assortativity_test, mmd_mean_degree_connectivity_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd(
            [np.bincount(num_nodes_test)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral/... are = {:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}".
            format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test,
                   mmd_mean_degree_test, mmd_max_degree_test, mmd_mean_centrality_test, mmd_assortativity_test,
                   mmd_mean_degree_connectivity_test))
  def test(self):
    self.config.save_dir_train = self.test_conf.test_model_dir

    if not self.config.test.is_test_ER:
      ### load model
      model = eval(self.model_conf.name)(self.config)
      model_file = os.path.join(self.config.save_dir_train, self.test_conf.test_model_name)
      load_model(model, model_file, self.device)
      if self.use_gpu:
        model = nn.DataParallel(model, device_ids=self.gpus).to(self.device)
      model.eval()

      if hasattr(self.config, 'complete_graph_model'):
        complete_graph_model = eval(self.config.complete_graph_model.name)(self.config.complete_graph_model)
        complete_graph_model_file = os.path.join(self.config.complete_graph_model.test_model_dir,
                                                 self.config.complete_graph_model.test_model_name)
        load_model(complete_graph_model, complete_graph_model_file, self.device)
        if self.use_gpu:
          complete_graph_model = nn.DataParallel(complete_graph_model, device_ids=self.gpus).to(self.device)
        complete_graph_model.eval()

    if self.config.test.is_test_ER or not hasattr(self.config.test, 'hard_multi') or not self.config.test.hard_multi:
      hard_thre_list = [None]
    else:
      hard_thre_list = np.arange(0.5, 1, 0.1)

    for test_hard_idx, hard_thre in enumerate(hard_thre_list):
      if self.config.test.is_test_ER:
        ### Compute Erdos-Renyi baseline
        p_ER = sum([aa.number_of_edges() for aa in self.graphs_train]) / sum([aa.number_of_nodes() ** 2 for aa in self.graphs_train])
        graphs_gen = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
      else:
        logger.info('Test pass {}. Hard threshold {}'.format(test_hard_idx, hard_thre))
        ### Generate Graphs
        A_pred = []
        num_nodes_pred = []
        num_test_batch = int(np.ceil(self.num_test_gen / self.test_conf.batch_size))

        gen_run_time = []
        for ii in tqdm(range(num_test_batch)):
          with torch.no_grad():
            start_time = time.time()
            input_dict = {}
            input_dict['is_sampling'] = True
            input_dict['batch_size'] = self.test_conf.batch_size
            input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
            input_dict['hard_thre'] = hard_thre
            A_tmp = model(input_dict)

            if hasattr(self.config, 'complete_graph_model'):
              final_A_list = []
              for batch_idx in range(len(A_tmp)):
                new_pmf = torch.zeros(len(self.num_nodes_pmf_train))
                max_prob = 0.
                max_prob_num_nodes = None
                for num_nodes, prob in enumerate(self.num_nodes_pmf_train):
                  if prob == 0.:
                    continue
                  tmp_data = {}
                  A_tmp_tmp = A_tmp[batch_idx][:num_nodes, :num_nodes]
                  tmp_data['adj'] = F.pad(
                    A_tmp_tmp,
                    (0, self.config.complete_graph_model.model.max_num_nodes-num_nodes, 0, 0),
                    'constant', value=.0)[None, None, ...]

                  adj = torch.tril(A_tmp_tmp, diagonal=-1)
                  adj = adj + adj.transpose(0, 1)
                  edges = adj.to_sparse().coalesce().indices()
                  tmp_data['edges'] = edges.t()
                  tmp_data['subgraph_idx'] = torch.zeros(num_nodes).long().to(self.device, non_blocking=True)

                  tmp_logit = complete_graph_model(tmp_data)
                  new_pmf[num_nodes] = torch.sigmoid(tmp_logit).item()

                  if new_pmf[num_nodes] > max_prob:
                    max_prob = new_pmf[num_nodes]
                    max_prob_num_nodes = num_nodes

                  if new_pmf[num_nodes] <= 0.9:
                    new_pmf[num_nodes] = 0.

                if (new_pmf == 0.).all():
                  logger.info('(new_pmf == 0.).all(), use {} nodes with max prob {}'.format(max_prob_num_nodes, max_prob))
                  final_num_nodes = max_prob_num_nodes
                else:
                  final_num_nodes = torch.multinomial(new_pmf, 1).item()
                final_A_list.append(
                  A_tmp_tmp[:final_num_nodes, :final_num_nodes]
                )
              A_tmp = final_A_list
            gen_run_time += [time.time() - start_time]
            A_pred += [aa.cpu().numpy() for aa in A_tmp]
            num_nodes_pred += [aa.shape[0] for aa in A_tmp]
        print('num_nodes_pred', num_nodes_pred)
        logger.info('Average test time per mini-batch = {}'.format(
          np.mean(gen_run_time)))

        graphs_gen = [get_graph(aa) for aa in A_pred]

      ### Visualize Generated Graphs
      if self.is_vis:
        num_col = self.vis_num_row
        num_row = self.num_vis // num_col
        test_epoch = self.test_conf.test_model_name
        test_epoch = test_epoch[test_epoch.rfind('_') + 1:test_epoch.find('.pth')]
        if hard_thre is not None:
          save_name = os.path.join(self.config.save_dir_train, '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
            self.config.test.test_model_name[:-4], test_epoch,
            int(round(hard_thre*10))))
          save_name2 = os.path.join(self.config.save_dir,
                                   '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                                     self.config.test.test_model_name[:-4], test_epoch,
                                     int(round(hard_thre * 10))))
        else:
          save_name = os.path.join(self.config.save_dir_train,
                                   '{}_gen_graphs_epoch_{}.png'.format(
                                     self.config.test.test_model_name[:-4], test_epoch))
          save_name2 = os.path.join(self.config.save_dir,
                                    '{}_gen_graphs_epoch_{}.png'.format(
                                      self.config.test.test_model_name[:-4], test_epoch))

        # remove isolated nodes for better visulization
        graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]]

        if self.better_vis:
          # actually not necessary with the following largest connected component selection
          for gg in graphs_pred_vis:
            gg.remove_nodes_from(list(nx.isolates(gg)))

        # display the largest connected component for better visualization
        vis_graphs = []
        for gg in graphs_pred_vis:
          if self.better_vis:
            CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
            CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
            vis_graphs += [CGs[0]]
          else:
            vis_graphs += [gg]
        print('number of nodes after better vis', [tmp_g.number_of_nodes() for tmp_g in vis_graphs])

        if self.is_single_plot:
          # draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
          draw_graph_list(vis_graphs, num_row, num_col, fname=save_name2, layout='spring')
        else:
          # draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')
          draw_graph_list_separate(vis_graphs, fname=save_name2[:-4], is_single=True, layout='spring')

        if test_hard_idx == 0:
          save_name = os.path.join(self.config.save_dir_train, 'train_graphs.png')

          if self.is_single_plot:
            draw_graph_list(
              self.graphs_train[:self.num_vis],
              num_row,
              num_col,
              fname=save_name,
              layout='spring')
          else:
            draw_graph_list_separate(
              self.graphs_train[:self.num_vis],
              fname=save_name[:-4],
              is_single=True,
              layout='spring')

      ### Evaluation
      if self.config.dataset.name in ['lobster']:
        acc = eval_acc_lobster_graph(graphs_gen)
        logger.info('Validity accuracy of generated graphs = {}'.format(acc))

      num_nodes_gen = [len(aa) for aa in graphs_gen]

      # Compared with Validation Set
      num_nodes_dev = [len(gg.nodes) for gg in self.graphs_dev]  # shape B X 1
      mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(self.graphs_dev, graphs_gen,
                                                                                       degree_only=False)
      mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

      # Compared with Test Set
      num_nodes_test = [len(gg.nodes) for gg in self.graphs_test]  # shape B X 1
      mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(self.graphs_test, graphs_gen,
                                                                                           degree_only=False)
      mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

      logger.info(
        "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}".format(Decimal(mmd_num_nodes_dev),
                                                                                                   Decimal(mmd_degree_dev),
                                                                                                   Decimal(mmd_clustering_dev),
                                                                                                   Decimal(mmd_4orbits_dev),
                                                                                                   Decimal(mmd_spectral_dev)))
      logger.info(
        "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}".format(Decimal(mmd_num_nodes_test),
                                                                                                   Decimal(mmd_degree_test),
                                                                                                   Decimal(mmd_clustering_test),
                                                                                                   Decimal(mmd_4orbits_test),
                                                                                                   Decimal(mmd_spectral_test)))
Beispiel #13
0
    def test(self):
        """Sample graphs from the trained generator and log MMD metrics.

        The generator weights are loaded from ``test_conf.test_model_dir`` /
        ``test_conf.test_model_name``. For every batch of the test loader one
        adjacency matrix per sample is generated, conditioned on one-hot
        encodings of the sample's ``m`` and ``n`` labels. The generated graphs
        are optionally plotted and then compared against the validation and
        test graphs.

        If ``config.test.hard_multi`` is set, the whole procedure is repeated
        for every threshold in ``np.arange(0.5, 1, 0.1)`` and edges are kept
        where the generator output exceeds the threshold; otherwise edges are
        drawn with ``torch.bernoulli`` from the generator output.

        Side effects:
            Overwrites ``self.config.save_dir_train``, writes plots through
            ``draw_graph_list`` / ``draw_graph_list_separate`` when
            ``self.is_vis`` is set, and reports all scores through ``logger``.

        Returns:
            None.
        """
        self.config.save_dir_train = self.test_conf.test_model_dir

        ### test dataset
        # NOTE(review): eval() resolves the loader class from a config string;
        # only run this with trusted configuration files.
        test_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                           self.graphs_test,
                                                           tag='test')
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.test_conf.batch_size,
            shuffle=False,
            num_workers=self.train_conf.num_workers,
            collate_fn=test_dataset.collate_fn,
            drop_last=False)

        ### load model
        args = self.config.model
        # width of the concatenated (m, n) one-hot condition vector
        n_labels = self.dataset_conf.max_m + self.dataset_conf.max_n
        G = define_G(args.nz, args.ngf, args.netG, args.final_activation,
                     args.norm_G)
        model_file_G = os.path.join(self.config.save_dir_train,
                                    self.test_conf.test_model_name)

        load_model(G, model_file_G, self.device)
        if self.use_gpu:
            G = G.cuda()
        # NOTE(review): the generator is put in train mode for sampling;
        # presumably intentional for this model -- confirm before changing.
        G.train()

        # A single pass without thresholding unless hard_multi is enabled.
        if getattr(self.config.test, 'hard_multi', False):
            hard_thre_list = np.arange(0.5, 1, 0.1)
        else:
            hard_thre_list = [None]

        for test_hard_idx, hard_thre in enumerate(hard_thre_list):
            logger.info('Test pass {}. Hard threshold {}'.format(
                test_hard_idx, hard_thre))
            ### Generate Graphs
            A_pred = []
            gen_run_time = []

            for batch_data in test_loader:
                # only the first entry of the collated batch is used
                # (asserted in arg helper)
                ff = 0

                with torch.no_grad():
                    data = {}
                    data['adj'] = batch_data[ff]['adj'].pin_memory().to(
                        self.config.device, non_blocking=True)
                    data['m'] = batch_data[ff]['m'].to(self.config.device,
                                                       non_blocking=True)
                    data['n'] = batch_data[ff]['n'].to(self.config.device,
                                                       non_blocking=True)

                    batch_size = data['adj'].size(0)

                    # One-hot encode the (1-based) m and n labels.
                    i_onehot = torch.zeros(
                        (batch_size, self.dataset_conf.max_m),
                        requires_grad=True).pin_memory().to(self.config.device,
                                                            non_blocking=True)
                    i_onehot.scatter_(1, data['m'][:, None] - 1, 1)
                    j_onehot = torch.zeros(
                        (batch_size, self.dataset_conf.max_n),
                        requires_grad=True).pin_memory().to(self.config.device,
                                                            non_blocking=True)
                    j_onehot.scatter_(1, data['n'][:, None] - 1, 1)
                    y_onehot = torch.cat((i_onehot, j_onehot), dim=1)

                    # Fill the remaining latent channels with Gaussian noise
                    # when the latent size exceeds the label width.
                    if args.nz > n_labels:
                        noise = torch.randn(
                            (batch_size, args.nz - n_labels, 1, 1),
                            requires_grad=True).to(self.config.device,
                                                   non_blocking=True)
                        z_input = torch.cat(
                            (y_onehot.view(batch_size, n_labels, 1, 1), noise),
                            dim=1)
                    else:
                        z_input = y_onehot.view(batch_size, n_labels, 1, 1)

                    start_time = time.time()
                    output = G(z_input).squeeze(1)  # (B, 1, n, n)
                    if self.model_conf.final_activation == 'tanh':
                        # map tanh output from [-1, 1] to [0, 1]
                        output = (output + 1) / 2
                    if self.model_conf.is_sym:
                        # mirror the strict lower triangle to symmetrize
                        output = torch.tril(output, diagonal=-1)
                        output = output + output.transpose(1, 2)
                    gen_run_time += [time.time() - start_time]

                    if hard_thre is not None:
                        A_pred += [(output[batch_idx, ...] >
                                    hard_thre).long().cpu().numpy()
                                   for batch_idx in range(batch_size)]
                    else:
                        A_pred += [
                            torch.bernoulli(output[batch_idx,
                                                   ...]).long().cpu().numpy()
                            for batch_idx in range(batch_size)
                        ]

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))

            graphs_gen = [get_graph(aa) for aa in A_pred]

            ### Visualize Generated Graphs
            if self.is_vis:
                num_col = self.vis_num_row
                # Round up so that num_vis graphs always fit into the grid,
                # also when num_vis is not a multiple of the column count.
                num_row = int(np.ceil(self.num_vis / num_col))
                test_epoch = self.test_conf.test_model_name
                test_epoch = test_epoch[test_epoch.rfind('_') +
                                        1:test_epoch.find('.pth')]
                model_tag = self.config.test.test_model_name[:-4]
                if hard_thre is not None:
                    gen_fname = '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                        model_tag, test_epoch, int(round(hard_thre * 10)))
                else:
                    gen_fname = '{}_gen_graphs_epoch_{}.png'.format(
                        model_tag, test_epoch)
                # Generated-graph plots go to save_dir (not save_dir_train).
                gen_plot_path = os.path.join(self.config.save_dir, gen_fname)

                # work on copies so the graphs used for evaluation stay intact
                graphs_pred_vis = [
                    copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
                ]

                if self.better_vis:
                    # remove isolated nodes for better visualization (not
                    # strictly necessary given the component selection below)
                    for gg in graphs_pred_vis:
                        gg.remove_nodes_from(list(nx.isolates(gg)))

                # display the largest connected component for better visualization
                vis_graphs = []
                for gg in graphs_pred_vis:
                    if self.better_vis:
                        components = [
                            gg.subgraph(c) for c in nx.connected_components(gg)
                        ]
                        # A graph without edges is empty after isolate removal
                        # and has no component; fall back to the (empty) graph
                        # itself instead of raising IndexError.
                        vis_graphs.append(
                            max(components,
                                key=lambda x: x.number_of_nodes(),
                                default=gg))
                    else:
                        vis_graphs.append(gg)
                print('number of nodes after better vis',
                      [tmp_g.number_of_nodes() for tmp_g in vis_graphs])

                if self.is_single_plot:
                    draw_graph_list(vis_graphs,
                                    num_row,
                                    num_col,
                                    fname=gen_plot_path,
                                    layout='spring')
                else:
                    draw_graph_list_separate(vis_graphs,
                                             fname=gen_plot_path[:-4],
                                             is_single=True,
                                             layout='spring')

                # The training graphs do not depend on the threshold, so they
                # are plotted only during the first pass.
                if test_hard_idx == 0:
                    train_plot_path = os.path.join(self.config.save_dir_train,
                                                   'train_graphs.png')

                    if self.is_single_plot:
                        draw_graph_list(self.graphs_train[:self.num_vis],
                                        num_row,
                                        num_col,
                                        fname=train_plot_path,
                                        layout='spring')
                    else:
                        draw_graph_list_separate(
                            self.graphs_train[:self.num_vis],
                            fname=train_plot_path[:-4],
                            is_single=True,
                            layout='spring')

            ### Evaluation
            if self.config.dataset.name in ['lobster']:
                acc = eval_acc_lobster_graph(graphs_gen)
                logger.info(
                    'Validity accuracy of generated graphs = {}'.format(acc))

            num_nodes_gen = [len(aa) for aa in graphs_gen]

            # Compared with Validation Set
            num_nodes_dev = [len(gg.nodes) for gg in self.graphs_dev]
            mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
                self.graphs_dev, graphs_gen, degree_only=False)
            mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                            [np.bincount(num_nodes_gen)],
                                            kernel=gaussian_emd)

            # Compared with Test Set
            num_nodes_test = [len(gg.nodes) for gg in self.graphs_test]
            mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
                self.graphs_test, graphs_gen, degree_only=False)
            mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                             [np.bincount(num_nodes_gen)],
                                             kernel=gaussian_emd)

            logger.info(
                "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}"
                .format(Decimal(mmd_num_nodes_dev), Decimal(mmd_degree_dev),
                        Decimal(mmd_clustering_dev), Decimal(mmd_4orbits_dev),
                        Decimal(mmd_spectral_dev)))
            logger.info(
                "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}"
                .format(Decimal(mmd_num_nodes_test), Decimal(mmd_degree_test),
                        Decimal(mmd_clustering_test),
                        Decimal(mmd_4orbits_test), Decimal(mmd_spectral_test)))
Beispiel #14
0
def test(args, config, model, dataset):
    """Generate graphs (or an Erdos-Renyi baseline) and report MMD metrics.

    Args:
        args: Only ``args.dev`` is read; it is passed to ``load_model`` as the
            device argument.
        config: Experiment configuration. ``config.test``, ``config.model``
            and ``config.dataset.name`` are read; ``config.save_dir`` is
            overwritten with ``config.test.test_model_dir``.
        model: Generative model. It is called with a dict holding
            ``is_sampling``, ``batch_size`` and ``num_nodes_pmf`` and must
            return an iterable of adjacency tensors. Not used when
            ``config.test.is_test_ER`` is set.
        dataset: Object exposing ``graphs_train``, ``graphs_dev``,
            ``graphs_test`` and ``train_dataset.num_nodes_pmf_train``.

    Returns:
        A tuple ``(mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev,
        mmd_spectral_dev, mmd_degree_test, mmd_clustering_test,
        mmd_4orbits_test, mmd_spectral_test)``. For the ``lobster`` dataset
        the validity accuracy is appended as a ninth element.
    """
    config.save_dir = config.test.test_model_dir

    if config.test.is_test_ER:
        ### Compute Erdos-Renyi baseline
        total_edges = sum(aa.number_of_edges() for aa in dataset.graphs_train)
        total_node_pairs = sum(
            aa.number_of_nodes() ** 2 for aa in dataset.graphs_train)
        p_ER = total_edges / total_node_pairs
        graphs_gen = [
            nx.fast_gnp_random_graph(config.model.max_num_nodes, p_ER, seed=ii)
            for ii in range(config.test.num_test_gen)
        ]
    else:
        ### load model
        model_file = os.path.join(config.save_dir, config.test.test_model_name)
        load_model(model, model_file, args.dev)

        model.eval()

        ### Generate Graphs
        A_pred = []
        # Whole batches are sampled, so up to batch_size - 1 more graphs than
        # num_test_gen may be produced.
        num_test_batch = int(
            np.ceil(config.test.num_test_gen / config.test.batch_size))

        gen_run_time = []
        for _ in tqdm(range(num_test_batch)):
            with torch.no_grad():
                start_time = time.time()
                input_dict = {
                    'is_sampling': True,
                    'batch_size': config.test.batch_size,
                    'num_nodes_pmf': dataset.train_dataset.num_nodes_pmf_train,
                }
                A_tmp = model(input_dict)
                gen_run_time.append(time.time() - start_time)
                A_pred += [aa.detach().cpu().numpy() for aa in A_tmp]

        print('Average test time per mini-batch = {}'.format(
            np.mean(gen_run_time)))

        graphs_gen = [get_graph(aa) for aa in A_pred]

    ### Visualize Generated Graphs
    if config.test.is_vis:
        num_col = config.test.vis_num_row
        num_row = int(np.ceil(config.test.num_vis / num_col))
        test_epoch = config.test.test_model_name
        test_epoch = test_epoch[test_epoch.rfind('_') + 1:test_epoch.find('.pth')]
        plot_dir = os.path.join(config.save_dir, 'plots')
        # exist_ok avoids the check-then-create race and also creates
        # missing parent directories.
        os.makedirs(plot_dir, exist_ok=True)
        save_name = os.path.join(
            plot_dir,
            '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                config.test.test_model_name[:-4],
                test_epoch,
                config.model.block_size,
                config.model.sample_stride))

        # work on copies so the graphs used for evaluation stay intact
        graphs_pred_vis = [
            copy.deepcopy(gg) for gg in graphs_gen[:config.test.num_vis]
        ]

        if config.test.better_vis:
            # remove isolated nodes for better visualization
            for gg in graphs_pred_vis:
                gg.remove_nodes_from(list(nx.isolates(gg)))

        # display the largest connected component for better visualization
        vis_graphs = []
        for gg in graphs_pred_vis:
            components = [gg.subgraph(c) for c in nx.connected_components(gg)]
            # A graph without nodes (e.g. an edgeless graph after isolate
            # removal) has no component; fall back to the graph itself
            # instead of raising IndexError.
            vis_graphs.append(
                max(components, key=lambda x: x.number_of_nodes(), default=gg))

        if config.test.is_single_plot:
            draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
        else:
            draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')

        # Plot the same number of training graphs for visual comparison.
        save_name = os.path.join(plot_dir, 'train_graphs.png')

        if config.test.is_single_plot:
            draw_graph_list(
                dataset.graphs_train[:config.test.num_vis],
                num_row,
                num_col,
                fname=save_name,
                layout='spring')
        else:
            draw_graph_list_separate(
                dataset.graphs_train[:config.test.num_vis],
                fname=save_name[:-4],
                is_single=True,
                layout='spring')

    ### Evaluation
    is_lobster = config.dataset.name in ['lobster']
    if is_lobster:
        acc = eval_acc_lobster_graph(graphs_gen)
        print('Validity accuracy of generated graphs = {}'.format(acc))

    num_nodes_gen = [len(aa) for aa in graphs_gen]

    # Compared with Validation Set (the fifth return value is not used here)
    num_nodes_dev = [len(gg.nodes) for gg in dataset.graphs_dev]
    mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, _ = evaluate_generated(
        dataset.graphs_dev, graphs_gen, degree_only=False)
    mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

    # Compared with Test Set (the fifth return value is not used here)
    num_nodes_test = [len(gg.nodes) for gg in dataset.graphs_test]
    mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, _ = evaluate_generated(
        dataset.graphs_test, graphs_gen, degree_only=False)
    mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

    print("Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}".format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev))
    print("Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}".format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test))

    scores = (mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev,
              mmd_spectral_dev, mmd_degree_test, mmd_clustering_test,
              mmd_4orbits_test, mmd_spectral_test)
    if is_lobster:
        return scores + (acc,)
    return scores
Beispiel #15
0
    def test(self):
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        if self.config.test.is_test_ER:
            p_ER = sum([
                aa.number_of_edges() for aa in self.graphs_train
            ]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
            graphs_gen = [
                nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii)
                for ii in range(self.num_test_gen)
            ]
        else:
            ### load model
            model = eval(self.model_conf.name)(self.config)
            model_file = os.path.join(self.config.save_dir,
                                      self.test_conf.test_model_name)
            load_model(model, model_file, self.device)

            # create graph classifier
            graph_classifier = GraphSAGE(3, 2, 3, 32, 'add')
            # graph_classifier = DiffPool(3, 2, max_num_nodes=630)
            # graph_classifier = DGCNN(3, 2, 'PROTEINS_full')
            graph_classifier.load_state_dict(
                torch.load('output/MODEL_PROTEINS.pkl'))

            if self.use_gpu:
                model = nn.DataParallel(model,
                                        device_ids=self.gpus).to(self.device)
                graph_classifier = graph_classifier.to(self.device)

            model.eval()
            graph_classifier.eval()

            ### Generate Graphs
            A_pred = []
            num_nodes_pred = []
            num_test_batch = 5000

            gen_run_time = []
            graph_acc_count = 0
            # for ii in tqdm(range(1)):
            #     with torch.no_grad():
            #         start_time = time.time()
            #         input_dict = {}
            #         input_dict['is_sampling'] = True
            #         input_dict['batch_size'] = self.test_conf.batch_size
            #         input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
            #         A_tmp, label_tmp = model(input_dict)
            #         gen_run_time += [time.time() - start_time]
            #         A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
            #         num_nodes_pred += [aa.shape[0] for aa in A_tmp]

            from classifier.losses import MulticlassClassificationLoss
            classifier_loss = MulticlassClassificationLoss()

            ps = []

            acc_count_by_label = {0: 0, 1: 0}

            graph_acc_count = 0
            for ii in tqdm(range(2 * num_test_batch)):
                with torch.no_grad():
                    label = ii % 2
                    graph_label = torch.tensor([label]).to('cuda').long()
                    start_time = time.time()
                    input_dict = {}
                    input_dict['is_sampling'] = True
                    input_dict['batch_size'] = self.test_conf.batch_size
                    input_dict['num_nodes_pmf'] = self.num_nodes_pmf_by_group[
                        graph_label.item()]
                    input_dict['graph_label'] = graph_label

                    A_tmp, label_tmp = model(input_dict)
                    A_tmp = A_tmp[0]
                    label_tmp = label_tmp[0]

                    label_tmp = label_tmp.long()

                    lower_part = torch.tril(A_tmp, diagonal=-1)

                    x = torch.zeros((A_tmp.shape[0], 3)).to(self.device)
                    x[list(range(A_tmp.shape[0])), label_tmp] = 1

                    edge_mask = (lower_part != 0).to(self.device)
                    edges = edge_mask.nonzero().transpose(0, 1).to(self.device)
                    edges_other_way = edges[[1, 0]]
                    edges = torch.cat([edges, edges_other_way],
                                      dim=-1).to(self.device)

                    batch = torch.zeros(A_tmp.shape[0]).long().to(self.device)

                    data = Bunch(x=x,
                                 edge_index=edges,
                                 batch=batch,
                                 y=graph_label,
                                 edge_weight=None)

                    n_nodes = batch.shape[0]
                    n_edges = edges.shape[1]

                    output = graph_classifier(data)

                    if not isinstance(output, tuple):
                        output = (output, )

                    graph_classification_loss, graph_classification_acc = classifier_loss(
                        data.y, *output)
                    graph_acc_count += graph_classification_acc / 100

                    acc_count_by_label[label] += graph_classification_acc / 100

                    print(graph_classification_acc, graph_label)

                    if ii % 100 == 99:
                        n_graphs_each = (ii + 1) / 2
                        print("\033[92m" +
                              "Class 0: %.3f ----  Class 1: %.3f" %
                              (acc_count_by_label[0] / n_graphs_each,
                               acc_count_by_label[1] / n_graphs_each) +
                              "\033[0m")

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))
            for label in [0, 1]:
                graph_acc_count = acc_count_by_label[label]
                logger.info('Class %s: ' % (label) +
                            'Conditional graph generation accuracy = {}'.
                            format(graph_acc_count / num_test_batch))

            graphs_gen = [get_graph(aa) for aa in A_pred]

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') +
                                    1:test_epoch.find('.pth')]
            save_name = os.path.join(
                self.config.save_dir,
                '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                    self.config.test.test_model_name[:-4], test_epoch,
                    self.block_size, self.stride))

            # remove isolated nodes for better visulization
            graphs_pred_vis = [
                copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
            ]

            if self.better_vis:
                for gg in graphs_pred_vis:
                    gg.remove_nodes_from(list(nx.isolates(gg)))

            # display the largest connected component for better visualization
            vis_graphs = []
            for gg in graphs_pred_vis:
                CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
                CGs = sorted(CGs,
                             key=lambda x: x.number_of_nodes(),
                             reverse=True)
                vis_graphs += [CGs[0]]

            if self.is_single_plot:
                draw_graph_list(vis_graphs,
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(vis_graphs,
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')

            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis],
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(self.graphs_train[:self.num_vis],
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info(
                'Validity accuracy of generated graphs = {}'.format(acc))

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # Compared with Validation Set
        num_nodes_dev = [len(gg.nodes)
                         for gg in self.graphs_dev]  # shape B X 1
        mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
            self.graphs_dev, graphs_gen, degree_only=False)
        mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                        [np.bincount(num_nodes_gen)],
                                        kernel=gaussian_emd)

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes)
                          for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                         [np.bincount(num_nodes_gen)],
                                         kernel=gaussian_emd)

        logger.info(
            "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev,
                    mmd_4orbits_dev, mmd_spectral_dev))
        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test,
                    mmd_4orbits_test, mmd_spectral_test))

        if self.config.dataset.name in ['lobster']:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, acc
        else:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test