예제 #1
0
  def eval(self):
    """Sample graphs from ``self.model_func``, rank them by size and optionally draw them.

    Runs ``ceil(self.samples / self.batch_size)`` sampling passes under ``torch.no_grad()``,
    converts every sampled adjacency with ``self.get_directed_graph``, sorts the resulting
    graphs by number of nodes (largest first) and keeps at most ``self.samples`` of them.

    ``self.draw_settings`` controls drawing: ``'all'`` draws every kept graph, ``'one'`` draws a
    single graph picked with ``self.random_gen.choice``; any other value draws nothing. Drawings
    are written by ``draw_graph_list_separate`` with ``<config_save_dir>/sampled_trees`` as the
    file-name prefix.

    Returns:
      list: the kept graphs, ordered by decreasing number of nodes.
    """
    eval_time = time.time()
    self.model_func.to(self.device)
    logger.debug("Starting on device={}".format(self.model_func.device))
    self.model_func.eval()

    num_test_batch = int(np.ceil(self.samples / self.batch_size))

    # Empirical PMF over the node counts recorded for training. It is the same for every
    # batch, so it is computed once instead of inside the sampling loop.
    num_node_dist = np.bincount(np.array(self.metrics['train']['node_dist']))
    num_nodes_pmf = num_node_dist / num_node_dist.sum()

    A_pred = []
    for _ in tqdm(range(num_test_batch)):
      with torch.no_grad():
        input_dict = {
          'is_sampling': True,
          'batch_size': self.batch_size,
          'num_nodes_pmf': num_nodes_pmf,
        }
        A_tmp = self.model_func(input_dict)
        A_pred += [aa.data.cpu().numpy() for aa in A_tmp]

    # Convert to networkx
    vis_graphs = [self.get_directed_graph(aa) for aa in A_pred]

    # Rank graphs by number of nodes, largest first. The key has to read the node count;
    # a constant key would leave the graphs in sampling order.
    ranked_arbos = sorted(vis_graphs, key=lambda gg: gg.number_of_nodes(), reverse=True)

    # Batched sampling may produce more than self.samples graphs; only keep self.samples.
    ranked_arbos = ranked_arbos[:self.samples]

    if self.draw_settings in ['all', 'one']:

      if self.draw_settings == 'all':
        drawn_arbos = ranked_arbos
      else:
        drawn_arbos = [self.random_gen.choice(ranked_arbos)]

      save_fname = os.path.join(self.config.config_save_dir, 'sampled_trees.png')

      # The drawing helper receives the path without its '.png' extension.
      draw_graph_list_separate(
        drawn_arbos,
        fname=save_fname[:-4],
        is_single=True,
        layout='kamada'
        )

    logger.debug("Generated {} Tree [{:2.2f} s]".format(self.samples, time.time() - eval_time))

    return ranked_arbos
예제 #2
0
파일: gran_runner.py 프로젝트: texbomb/GRAN
    def test(self):
        """Generate graphs (or an Erdos-Renyi baseline), optionally visualize them, and score them.

        When ``self.config.test.is_test_ER`` is set, ``self.num_test_gen`` G(n, p) graphs with
        ``n = self.max_num_nodes`` are drawn, with ``p`` estimated from the training graphs.
        Otherwise the model named by ``self.model_conf.name`` is loaded from
        ``self.test_conf.test_model_dir`` and sampled in batches of
        ``self.test_conf.batch_size``. Each sample's adjacency and alpha output go through
        ``get_graph``.

        If ``self.is_vis`` is set, the first ``self.num_vis`` generated graphs are pickled and
        drawn (largest connected component only), and the same number of training graphs are
        drawn for comparison.

        Returns:
            tuple: degree/clustering/4-orbit/spectral MMDs against the validation set followed
            by the same four MMDs against the test set. For the 'lobster' dataset the lobster
            validity accuracy is appended as a ninth element.
        """
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        if self.config.test.is_test_ER:
            # Edge probability = total edges / total (nodes ** 2) over the training graphs.
            p_ER = sum([
                aa.number_of_edges() for aa in self.graphs_train
            ]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
            graphs_gen = [
                nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii)
                for ii in range(self.num_test_gen)
            ]
        else:
            ### load model
            # NOTE: eval() on a config-supplied class name; only safe with trusted configs.
            model = eval(self.model_conf.name)(self.config)
            model_file = os.path.join(self.config.save_dir,
                                      self.test_conf.test_model_name)
            load_model(model, model_file, self.device)

            if self.use_gpu:
                model = nn.DataParallel(model,
                                        device_ids=self.gpus).to(self.device)

            model.eval()

            ### Generate Graphs
            A_pred = []
            alpha_list = []
            num_test_batch = int(
                np.ceil(self.num_test_gen / self.test_conf.batch_size))

            gen_run_time = []
            for _ in tqdm(range(num_test_batch)):
                with torch.no_grad():
                    start_time = time.time()
                    input_dict = {}
                    input_dict['is_sampling'] = True
                    input_dict['batch_size'] = self.test_conf.batch_size
                    input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
                    A_tmp, alpha_temp = model(input_dict)
                    gen_run_time += [time.time() - start_time]
                    A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
                    alpha_list += [aa.data.cpu().numpy() for aa in alpha_temp]

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))

            graphs_gen = [
                get_graph(aa, alpha_list[i]) for i, aa in enumerate(A_pred)
            ]

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') +
                                    1:test_epoch.find('.pth')]
            save_name = os.path.join(
                self.config.save_dir,
                '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                    self.config.test.test_model_name[:-4], test_epoch,
                    self.block_size, self.stride))

            # Work on copies so the in-place edits below do not touch graphs_gen,
            # which is evaluated afterwards.
            graphs_pred_vis = [
                copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
            ]

            # Save each visualized graph (before isolated-node removal) as a pickle.
            for i, gg in enumerate(graphs_pred_vis):
                name = os.path.join(
                    self.config.save_dir,
                    '{}_gen_graphs_epoch_{}_{}.pickle'.format(
                        self.config.test.test_model_name[:-4], test_epoch, i))
                with open(name, 'wb') as handle:
                    pickle.dump(gg, handle)

            # remove isolated nodes for better visualization
            if self.better_vis:
                for gg in graphs_pred_vis:
                    gg.remove_nodes_from(list(nx.isolates(gg)))

            # Display the largest connected component for better visualization.
            # A graph left with no nodes has no components, so it is kept as-is
            # instead of indexing into an empty list.
            vis_graphs = []
            for gg in graphs_pred_vis:
                components = list(nx.connected_components(gg))
                if components:
                    vis_graphs += [gg.subgraph(max(components, key=len))]
                else:
                    vis_graphs += [gg]

            if self.is_single_plot:
                draw_graph_list(vis_graphs,
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(vis_graphs,
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')

            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis],
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(self.graphs_train[:self.num_vis],
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info(
                'Validity accuracy of generated graphs = {}'.format(acc))

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # Compared with Validation Set
        num_nodes_dev = [len(gg.nodes)
                         for gg in self.graphs_dev]  # shape B X 1
        mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
            self.graphs_dev, graphs_gen, degree_only=False)
        mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                        [np.bincount(num_nodes_gen)],
                                        kernel=gaussian_emd)

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes)
                          for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                         [np.bincount(num_nodes_gen)],
                                         kernel=gaussian_emd)

        logger.info(
            "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev,
                    mmd_4orbits_dev, mmd_spectral_dev))
        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test,
                    mmd_4orbits_test, mmd_spectral_test))

        if self.config.dataset.name in ['lobster']:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, acc
        else:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test
예제 #3
0
    def test(self):
        """Sample labelled graphs from the trained model and draw them.

        Steps performed:
          1. Build an Erdos-Renyi baseline (``graphs_baseline``). It holds ``self.num_test_gen``
             G(n, p) samples with ``n = self.max_num_nodes`` and ``p`` estimated from the
             training graphs, each reduced to its largest connected component. This is always
             computed (not gated on ``is_test_ER``) and is only referenced by commented-out code.
          2. Load the model named by ``self.model_conf.name`` from
             ``self.test_conf.test_model_dir``, then sample adjacency matrices and node labels.
          3. Convert the samples with ``get_graph_with_labels`` and run ``calculate_validity``
             on them. Per the original comment, this adds the ``'bipartite'`` graph attribute
             read in step 4.
          4. If ``self.is_vis``: draw every non-bipartite generated graph and the first
             ``self.num_vis`` training graphs.

        NOTE(review): the bare ``return`` after step 4 makes the "Evaluation" section
        unreachable. No lobster accuracy or MMD scores are computed, and the method always
        returns ``None``. Remove that ``return`` to re-enable evaluation.
        """
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        p_ER = sum([aa.number_of_edges()
                    for aa in self.graphs_train]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
        graphs_gen = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
        graphs_baseline = []
        for G in graphs_gen:
            G.remove_nodes_from(list(nx.isolates(G)))
            # Take the largest connected component. A sample without edges is empty once its
            # isolated nodes are removed and has no component, so it is skipped instead of
            # raising IndexError.
            components = list(nx.connected_components(G))
            if components:
                graphs_baseline.append(G.subgraph(max(components, key=len)))

        ### load model
        # NOTE: eval() on a config-supplied class name; only safe with trusted configs.
        model = eval(self.model_conf.name)(self.config)
        model_file = os.path.join(self.config.save_dir, self.test_conf.test_model_name)
        load_model(model, model_file, self.device)

        if self.use_gpu:
            model = nn.DataParallel(model, device_ids=self.gpus).to(self.device)

        model.eval()

        ### Generate Graphs
        A_pred = []
        node_label_pred = []
        num_test_batch = int(np.ceil(self.num_test_gen / self.test_conf.batch_size))

        gen_run_time = []
        for _ in tqdm(range(num_test_batch)):
            with torch.no_grad():
                start_time = time.time()
                input_dict = {}
                input_dict['is_sampling'] = True
                input_dict['batch_size'] = self.test_conf.batch_size
                input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
                A_tmp, node_label_tmp = model(input_dict)
                gen_run_time += [time.time() - start_time]
                A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
                node_label_pred += [ll.data.cpu().numpy() for ll in node_label_tmp]

        logger.info('Average test time per mini-batch = {}'.format(np.mean(gen_run_time)))

        graphs_gen = [get_graph_with_labels(aa, ll) for aa, ll in zip(A_pred, node_label_pred)]
        valid_pctg, bipartite_pctg = calculate_validity(graphs_gen)  # for adding bipartite graph attribute

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') + 1:test_epoch.find('.pth')]
            save_name = os.path.join(self.config.save_dir, '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                self.config.test.test_model_name[:-4], test_epoch, self.block_size, self.stride))

            # Draw every generated graph that is not bipartite (not limited to self.num_vis).
            graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen if not gg.graph['bipartite']]
            logger.info('Number of not bipartite graphs: {} / {}'.format(len(graphs_pred_vis), len(graphs_gen)))
            # Isolated-node removal / largest-connected-component selection is intentionally
            # not applied here (note: gg.subgraph(...) would return a frozen view).
            vis_graphs = graphs_pred_vis

            if self.is_single_plot:
                draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
            else:  #XD: using this for now
                draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')
            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis], num_row, num_col, fname=save_name, layout='spring')
                print('training single plot saved at:', save_name)
            else:  #XD: using this for now
                graph_list_train = [get_graph_from_nx(G) for G in self.graphs_train[:self.num_vis]]
                draw_graph_list_separate(graph_list_train, fname=save_name[:-4], is_single=True, layout='spring')
                print('training plots saved individually at:', save_name[:-4])
        return

        # ---- Everything below is unreachable because of the `return` above. ----
        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info('Validity accuracy of generated graphs = {}'.format(acc))
        # =====XD=====
        ## graphs_gen = [generate_random_baseline_single(len(aa)) for aa in graphs_gen]  # use this line for random baseline MMD scores. Remember to comment it later!
        # draw_hists(self.graphs_test, graphs_baseline, graphs_gen)
        valid_pctg, bipartite_pctg = calculate_validity(graphs_gen)
        # =====XD=====

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # Validation-set comparison is disabled in this variant; only the test set is scored.

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes) for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, mmd_mean_degree_test, mmd_max_degree_test, mmd_mean_centrality_test, mmd_assortativity_test, mmd_mean_degree_connectivity_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd(
            [np.bincount(num_nodes_test)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral/... are = {:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}".
            format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test,
                   mmd_mean_degree_test, mmd_max_degree_test, mmd_mean_centrality_test, mmd_assortativity_test,
                   mmd_mean_degree_connectivity_test))
  def test(self):
    """Generate graphs (or an Erdos-Renyi baseline), visualize them and log MMD scores.

    When ``self.config.test.is_test_ER`` is false, the model named by ``self.model_conf.name``
    is loaded from ``self.test_conf.test_model_dir``. If ``self.config`` has a
    ``complete_graph_model`` entry, that second model is loaded too. Its sigmoid score on
    truncated adjacencies is used to re-choose the number of nodes of every sample.

    One test pass is run per hard threshold. There is a single pass with ``hard_thre=None``
    unless ``self.config.test.hard_multi`` is set (and ER testing is off); then the thresholds
    ``np.arange(0.5, 1, 0.1)`` are used. Each pass generates graphs and optionally draws them.
    It then logs node-count/degree/clustering/4-orbit/spectral MMDs against the validation and
    test graphs. Nothing is returned.
    """
    self.config.save_dir_train = self.test_conf.test_model_dir

    if not self.config.test.is_test_ER:
      ### load model
      # NOTE: eval() on a config-supplied class name; only safe with trusted configs.
      model = eval(self.model_conf.name)(self.config)
      model_file = os.path.join(self.config.save_dir_train, self.test_conf.test_model_name)
      load_model(model, model_file, self.device)
      if self.use_gpu:
        model = nn.DataParallel(model, device_ids=self.gpus).to(self.device)
      model.eval()

      # Optional second model; its sigmoid output is used below as a score for each
      # candidate number of nodes.
      if hasattr(self.config, 'complete_graph_model'):
        complete_graph_model = eval(self.config.complete_graph_model.name)(self.config.complete_graph_model)
        complete_graph_model_file = os.path.join(self.config.complete_graph_model.test_model_dir,
                                                 self.config.complete_graph_model.test_model_name)
        load_model(complete_graph_model, complete_graph_model_file, self.device)
        if self.use_gpu:
          complete_graph_model = nn.DataParallel(complete_graph_model, device_ids=self.gpus).to(self.device)
        complete_graph_model.eval()

    if self.config.test.is_test_ER or not hasattr(self.config.test, 'hard_multi') or not self.config.test.hard_multi:
      hard_thre_list = [None]
    else:
      hard_thre_list = np.arange(0.5, 1, 0.1)

    for test_hard_idx, hard_thre in enumerate(hard_thre_list):
      if self.config.test.is_test_ER:
        ### Compute Erdos-Renyi baseline
        p_ER = sum([aa.number_of_edges() for aa in self.graphs_train]) / sum([aa.number_of_nodes() ** 2 for aa in self.graphs_train])
        graphs_gen = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
      else:
        logger.info('Test pass {}. Hard threshold {}'.format(test_hard_idx, hard_thre))
        ### Generate Graphs
        A_pred = []
        num_nodes_pred = []
        num_test_batch = int(np.ceil(self.num_test_gen / self.test_conf.batch_size))

        gen_run_time = []
        for _ in tqdm(range(num_test_batch)):
          with torch.no_grad():
            start_time = time.time()
            input_dict = {}
            input_dict['is_sampling'] = True
            input_dict['batch_size'] = self.test_conf.batch_size
            input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
            input_dict['hard_thre'] = hard_thre
            A_tmp = model(input_dict)

            if hasattr(self.config, 'complete_graph_model'):
              # For every sample, score each node count with non-zero training probability
              # and re-choose how many nodes of the sample to keep.
              final_A_list = []
              for A_full in A_tmp:
                new_pmf = torch.zeros(len(self.num_nodes_pmf_train))
                max_prob = 0.
                max_prob_num_nodes = None
                for num_nodes, prob in enumerate(self.num_nodes_pmf_train):
                  if prob == 0.:
                    continue
                  A_sub = A_full[:num_nodes, :num_nodes]
                  tmp_data = {}
                  # NOTE(review): this pad spec only extends the last (column) dimension to
                  # max_num_nodes; the row dimension stays num_nodes. Confirm that is the
                  # input shape complete_graph_model expects.
                  tmp_data['adj'] = F.pad(
                    A_sub,
                    (0, self.config.complete_graph_model.model.max_num_nodes-num_nodes, 0, 0),
                    'constant', value=.0)[None, None, ...]

                  # Symmetrize the strictly-lower triangle and list its non-zero entries as edges.
                  adj = torch.tril(A_sub, diagonal=-1)
                  adj = adj + adj.transpose(0, 1)
                  edges = adj.to_sparse().coalesce().indices()
                  tmp_data['edges'] = edges.t()
                  tmp_data['subgraph_idx'] = torch.zeros(num_nodes).long().to(self.device, non_blocking=True)

                  tmp_logit = complete_graph_model(tmp_data)
                  new_pmf[num_nodes] = torch.sigmoid(tmp_logit).item()

                  # Track the best score as a Python float. new_pmf[num_nodes] is a view into
                  # new_pmf, so holding it would let the zeroing below overwrite max_prob.
                  score = new_pmf[num_nodes].item()
                  if score > max_prob:
                    max_prob = score
                    max_prob_num_nodes = num_nodes

                  # Only node counts scored above 0.9 stay eligible for sampling.
                  if new_pmf[num_nodes] <= 0.9:
                    new_pmf[num_nodes] = 0.

                if (new_pmf == 0.).all():
                  logger.info('(new_pmf == 0.).all(), use {} nodes with max prob {}'.format(max_prob_num_nodes, max_prob))
                  final_num_nodes = max_prob_num_nodes
                else:
                  final_num_nodes = torch.multinomial(new_pmf, 1).item()
                # Slice the full sample, not the last candidate slice from the loop above.
                final_A_list.append(A_full[:final_num_nodes, :final_num_nodes])
              A_tmp = final_A_list
            gen_run_time += [time.time() - start_time]
            A_pred += [aa.cpu().numpy() for aa in A_tmp]
            num_nodes_pred += [aa.shape[0] for aa in A_tmp]
        print('num_nodes_pred', num_nodes_pred)
        logger.info('Average test time per mini-batch = {}'.format(
          np.mean(gen_run_time)))

        graphs_gen = [get_graph(aa) for aa in A_pred]

      ### Visualize Generated Graphs
      if self.is_vis:
        num_col = self.vis_num_row
        # Round up so every one of the self.num_vis graphs gets a grid cell.
        num_row = int(np.ceil(self.num_vis / num_col))
        test_epoch = self.test_conf.test_model_name
        test_epoch = test_epoch[test_epoch.rfind('_') + 1:test_epoch.find('.pth')]
        if hard_thre is not None:
          gen_fname = '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
            self.config.test.test_model_name[:-4], test_epoch,
            int(round(hard_thre * 10)))
        else:
          gen_fname = '{}_gen_graphs_epoch_{}.png'.format(
            self.config.test.test_model_name[:-4], test_epoch)
        gen_save_name = os.path.join(self.config.save_dir, gen_fname)

        # Work on copies so the in-place isolated-node removal does not touch graphs_gen,
        # which is evaluated below.
        graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]]

        if self.better_vis:
          # remove isolated nodes for better visualization
          # (actually not necessary with the following largest connected component selection)
          for gg in graphs_pred_vis:
            gg.remove_nodes_from(list(nx.isolates(gg)))

        # Display the largest connected component for better visualization.
        vis_graphs = []
        for gg in graphs_pred_vis:
          components = list(nx.connected_components(gg)) if self.better_vis else []
          if components:
            vis_graphs += [gg.subgraph(max(components, key=len))]
          else:
            # better_vis disabled, or no nodes left after removing isolated ones.
            vis_graphs += [gg]
        print('number of nodes after better vis', [tmp_g.number_of_nodes() for tmp_g in vis_graphs])

        if self.is_single_plot:
          draw_graph_list(vis_graphs, num_row, num_col, fname=gen_save_name, layout='spring')
        else:
          draw_graph_list_separate(vis_graphs, fname=gen_save_name[:-4], is_single=True, layout='spring')

        # Training graphs are drawn only once, on the first pass.
        if test_hard_idx == 0:
          save_name = os.path.join(self.config.save_dir_train, 'train_graphs.png')

          if self.is_single_plot:
            draw_graph_list(
              self.graphs_train[:self.num_vis],
              num_row,
              num_col,
              fname=save_name,
              layout='spring')
          else:
            draw_graph_list_separate(
              self.graphs_train[:self.num_vis],
              fname=save_name[:-4],
              is_single=True,
              layout='spring')

      ### Evaluation
      if self.config.dataset.name in ['lobster']:
        acc = eval_acc_lobster_graph(graphs_gen)
        logger.info('Validity accuracy of generated graphs = {}'.format(acc))

      num_nodes_gen = [len(aa) for aa in graphs_gen]

      # Compared with Validation Set
      num_nodes_dev = [len(gg.nodes) for gg in self.graphs_dev]  # shape B X 1
      mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(self.graphs_dev, graphs_gen,
                                                                                       degree_only=False)
      mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

      # Compared with Test Set
      num_nodes_test = [len(gg.nodes) for gg in self.graphs_test]  # shape B X 1
      mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(self.graphs_test, graphs_gen,
                                                                                           degree_only=False)
      mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

      logger.info(
        "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}".format(Decimal(mmd_num_nodes_dev),
                                                                                                   Decimal(mmd_degree_dev),
                                                                                                   Decimal(mmd_clustering_dev),
                                                                                                   Decimal(mmd_4orbits_dev),
                                                                                                   Decimal(mmd_spectral_dev)))
      logger.info(
        "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}".format(Decimal(mmd_num_nodes_test),
                                                                                                   Decimal(mmd_degree_test),
                                                                                                   Decimal(mmd_clustering_test),
                                                                                                   Decimal(mmd_4orbits_test),
                                                                                                   Decimal(mmd_spectral_test)))
예제 #5
0
    def test(self):
        self.config.save_dir_train = self.test_conf.test_model_dir

        ### test dataset
        test_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                           self.graphs_test,
                                                           tag='test')
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.test_conf.batch_size,
            shuffle=False,
            num_workers=self.train_conf.num_workers,
            collate_fn=test_dataset.collate_fn,
            drop_last=False)

        ### load model
        args = self.config.model
        n_labels = self.dataset_conf.max_m + self.dataset_conf.max_n
        G = define_G(args.nz, args.ngf, args.netG, args.final_activation,
                     args.norm_G)
        model_file_G = os.path.join(self.config.save_dir_train,
                                    self.test_conf.test_model_name)

        load_model(G, model_file_G, self.device)
        if self.use_gpu:
            G = G.cuda()  #nn.DataParallel(G).to(self.device)
        G.train()

        if not hasattr(self.config.test,
                       'hard_multi') or not self.config.test.hard_multi:
            hard_thre_list = [None]
        else:
            hard_thre_list = np.arange(0.5, 1, 0.1)

        for test_hard_idx, hard_thre in enumerate(hard_thre_list):
            logger.info('Test pass {}. Hard threshold {}'.format(
                test_hard_idx, hard_thre))
            ### Generate Graphs
            A_pred = []
            gen_run_time = []

            for batch_data in test_loader:
                # asserted in arg helper
                ff = 0

                with torch.no_grad():
                    data = {}
                    data['adj'] = batch_data[ff]['adj'].pin_memory().to(
                        self.config.device, non_blocking=True)
                    data['m'] = batch_data[ff]['m'].to(self.config.device,
                                                       non_blocking=True)
                    data['n'] = batch_data[ff]['n'].to(self.config.device,
                                                       non_blocking=True)

                    batch_size = data['adj'].size(0)

                    i_onehot = torch.zeros(
                        (batch_size, self.dataset_conf.max_m),
                        requires_grad=True).pin_memory().to(self.config.device,
                                                            non_blocking=True)
                    i_onehot.scatter_(1, data['m'][:, None] - 1, 1)
                    j_onehot = torch.zeros(
                        (batch_size, self.dataset_conf.max_n),
                        requires_grad=True).pin_memory().to(self.config.device,
                                                            non_blocking=True)
                    j_onehot.scatter_(1, data['n'][:, None] - 1, 1)
                    y_onehot = torch.cat((i_onehot, j_onehot), dim=1)

                    if args.nz > n_labels:
                        noise = torch.randn(
                            (batch_size, args.nz - n_labels, 1, 1),
                            requires_grad=True).to(self.config.device,
                                                   non_blocking=True)
                        z_input = torch.cat(
                            (y_onehot.view(batch_size, n_labels, 1, 1), noise),
                            dim=1)
                    else:
                        z_input = y_onehot.view(batch_size, n_labels, 1, 1)

                    start_time = time.time()
                    output = G(z_input).squeeze(1)  # (B, 1, n, n)
                    if self.model_conf.final_activation == 'tanh':
                        output = (output + 1) / 2
                    if self.model_conf.is_sym:
                        output = torch.tril(output, diagonal=-1)
                        output = output + output.transpose(1, 2)
                    gen_run_time += [time.time() - start_time]

                    if hard_thre is not None:
                        A_pred += [(output[batch_idx, ...] >
                                    hard_thre).long().cpu().numpy()
                                   for batch_idx in range(batch_size)]
                    else:
                        A_pred += [
                            torch.bernoulli(output[batch_idx,
                                                   ...]).long().cpu().numpy()
                            for batch_idx in range(batch_size)
                        ]

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))

            graphs_gen = [get_graph(aa) for aa in A_pred]

            ### Visualize Generated Graphs
            if self.is_vis:
                num_col = self.vis_num_row
                num_row = self.num_vis // num_col
                test_epoch = self.test_conf.test_model_name
                test_epoch = test_epoch[test_epoch.rfind('_') +
                                        1:test_epoch.find('.pth')]
                if hard_thre is not None:
                    save_name = os.path.join(
                        self.config.save_dir_train,
                        '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch,
                            int(round(hard_thre * 10))))
                    save_name2 = os.path.join(
                        self.config.save_dir,
                        '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch,
                            int(round(hard_thre * 10))))
                else:
                    save_name = os.path.join(
                        self.config.save_dir_train,
                        '{}_gen_graphs_epoch_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch))
                    save_name2 = os.path.join(
                        self.config.save_dir,
                        '{}_gen_graphs_epoch_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch))

                # remove isolated nodes for better visulization
                graphs_pred_vis = [
                    copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
                ]

                if self.better_vis:
                    # actually not necessary with the following largest connected component selection
                    for gg in graphs_pred_vis:
                        gg.remove_nodes_from(list(nx.isolates(gg)))

                # display the largest connected component for better visualization
                vis_graphs = []
                for gg in graphs_pred_vis:
                    if self.better_vis:
                        CGs = [
                            gg.subgraph(c) for c in nx.connected_components(gg)
                        ]
                        CGs = sorted(CGs,
                                     key=lambda x: x.number_of_nodes(),
                                     reverse=True)
                        vis_graphs += [CGs[0]]
                    else:
                        vis_graphs += [gg]
                print('number of nodes after better vis',
                      [tmp_g.number_of_nodes() for tmp_g in vis_graphs])

                if self.is_single_plot:
                    # draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
                    draw_graph_list(vis_graphs,
                                    num_row,
                                    num_col,
                                    fname=save_name2,
                                    layout='spring')
                else:
                    # draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')
                    draw_graph_list_separate(vis_graphs,
                                             fname=save_name2[:-4],
                                             is_single=True,
                                             layout='spring')

                if test_hard_idx == 0:
                    save_name = os.path.join(self.config.save_dir_train,
                                             'train_graphs.png')

                    if self.is_single_plot:
                        draw_graph_list(self.graphs_train[:self.num_vis],
                                        num_row,
                                        num_col,
                                        fname=save_name,
                                        layout='spring')
                    else:
                        draw_graph_list_separate(
                            self.graphs_train[:self.num_vis],
                            fname=save_name[:-4],
                            is_single=True,
                            layout='spring')

            ### Evaluation
            if self.config.dataset.name in ['lobster']:
                acc = eval_acc_lobster_graph(graphs_gen)
                logger.info(
                    'Validity accuracy of generated graphs = {}'.format(acc))

            num_nodes_gen = [len(aa) for aa in graphs_gen]

            # Compared with Validation Set
            num_nodes_dev = [len(gg.nodes)
                             for gg in self.graphs_dev]  # shape B X 1
            mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
                self.graphs_dev, graphs_gen, degree_only=False)
            mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                            [np.bincount(num_nodes_gen)],
                                            kernel=gaussian_emd)

            # Compared with Test Set
            num_nodes_test = [len(gg.nodes)
                              for gg in self.graphs_test]  # shape B X 1
            mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
                self.graphs_test, graphs_gen, degree_only=False)
            mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                             [np.bincount(num_nodes_gen)],
                                             kernel=gaussian_emd)

            logger.info(
                "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}"
                .format(Decimal(mmd_num_nodes_dev), Decimal(mmd_degree_dev),
                        Decimal(mmd_clustering_dev), Decimal(mmd_4orbits_dev),
                        Decimal(mmd_spectral_dev)))
            logger.info(
                "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}"
                .format(Decimal(mmd_num_nodes_test), Decimal(mmd_degree_test),
                        Decimal(mmd_clustering_test),
                        Decimal(mmd_4orbits_test), Decimal(mmd_spectral_test)))
예제 #6
0
    def test(self):
        """Generate graphs, score them, optionally visualize, and report MMDs.

        Either builds an Erdos-Renyi baseline (``config.test.is_test_ER``) or
        loads the trained generator from ``test_conf.test_model_dir`` and
        samples ``2 * num_test_batch`` graphs conditioned alternately on graph
        label 0 and 1.  The first graph of every sampled mini-batch is fed to
        a pretrained GraphSAGE classifier (weights read from
        ``output/MODEL_PROTEINS.pkl``); the per-class classifier accuracy is
        logged as the conditional-generation accuracy.

        The generated graphs are then optionally drawn (``self.is_vis``) and
        compared with the dev and test splits via MMD on the #nodes
        histogram, degree, clustering, 4-orbit and spectral statistics.

        Returns:
            Tuple ``(mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev,
            mmd_spectral_dev, mmd_degree_test, mmd_clustering_test,
            mmd_4orbits_test, mmd_spectral_test)``; for the ``lobster``
            dataset the validity accuracy is appended as a ninth element.
        """
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        if self.config.test.is_test_ER:
            p_ER = sum([
                aa.number_of_edges() for aa in self.graphs_train
            ]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
            graphs_gen = [
                nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii)
                for ii in range(self.num_test_gen)
            ]
        else:
            ### load model
            # NOTE(review): model class is resolved with eval() on a config
            # string; only run with trusted configs.
            model = eval(self.model_conf.name)(self.config)
            model_file = os.path.join(self.config.save_dir,
                                      self.test_conf.test_model_name)
            load_model(model, model_file, self.device)

            # create graph classifier
            graph_classifier = GraphSAGE(3, 2, 3, 32, 'add')
            # graph_classifier = DiffPool(3, 2, max_num_nodes=630)
            # graph_classifier = DGCNN(3, 2, 'PROTEINS_full')
            # map_location lets GPU-saved weights load on a CPU-only run.
            graph_classifier.load_state_dict(
                torch.load('output/MODEL_PROTEINS.pkl',
                           map_location=self.device))

            if self.use_gpu:
                model = nn.DataParallel(model,
                                        device_ids=self.gpus).to(self.device)
                graph_classifier = graph_classifier.to(self.device)

            model.eval()
            graph_classifier.eval()

            from classifier.losses import MulticlassClassificationLoss
            classifier_loss = MulticlassClassificationLoss()

            ### Generate graphs, alternating the conditioning label 0 / 1
            A_pred = []
            gen_run_time = []
            num_test_batch = 5000  # number of sampled mini-batches per class
            graph_acc_count = 0
            acc_count_by_label = {0: 0, 1: 0}

            for ii in tqdm(range(2 * num_test_batch)):
                with torch.no_grad():
                    label = ii % 2
                    graph_label = torch.tensor([label],
                                               dtype=torch.long,
                                               device=self.device)
                    input_dict = {}
                    input_dict['is_sampling'] = True
                    input_dict['batch_size'] = self.test_conf.batch_size
                    input_dict['num_nodes_pmf'] = self.num_nodes_pmf_by_group[
                        label]
                    input_dict['graph_label'] = graph_label

                    start_time = time.time()
                    A_tmp, label_tmp = model(input_dict)
                    gen_run_time += [time.time() - start_time]

                    # Only the first graph of the sampled mini-batch is used.
                    A_tmp = A_tmp[0]
                    label_tmp = label_tmp[0].long()

                    # Keep the generated adjacency so that visualization and
                    # MMD evaluation below operate on the sampled graphs.
                    A_pred += [A_tmp.data.cpu().numpy()]

                    num_nodes = A_tmp.shape[0]

                    # One-hot node features from the generated node labels
                    # (3 columns; presumably the classifier's input dim --
                    # confirm against GraphSAGE's signature).
                    x = torch.zeros((num_nodes, 3)).to(self.device)
                    x[list(range(num_nodes)), label_tmp] = 1

                    # Undirected edge list: nonzeros of the strict lower
                    # triangle, then the same edges in the reverse direction.
                    lower_part = torch.tril(A_tmp, diagonal=-1)
                    edges = (lower_part != 0).nonzero().transpose(0, 1).to(
                        self.device)
                    edges = torch.cat([edges, edges[[1, 0]]],
                                      dim=-1).to(self.device)

                    # Every node belongs to the single graph with batch id 0.
                    batch = torch.zeros(num_nodes).long().to(self.device)

                    data = Bunch(x=x,
                                 edge_index=edges,
                                 batch=batch,
                                 y=graph_label,
                                 edge_weight=None)

                    output = graph_classifier(data)

                    if not isinstance(output, tuple):
                        output = (output, )

                    _, graph_classification_acc = classifier_loss(
                        data.y, *output)
                    # Accuracy appears to be reported in percent; scale to
                    # [0, 1] before accumulating.
                    graph_acc_count += graph_classification_acc / 100
                    acc_count_by_label[label] += graph_classification_acc / 100

                    logger.debug('classification acc = {}, label = {}'.format(
                        graph_classification_acc, label))

                    if ii % 100 == 99:
                        # Labels alternate, so each class has seen half of the
                        # iterations so far.
                        n_graphs_each = (ii + 1) / 2
                        print("\033[92m" +
                              "Class 0: %.3f ----  Class 1: %.3f" %
                              (acc_count_by_label[0] / n_graphs_each,
                               acc_count_by_label[1] / n_graphs_each) +
                              "\033[0m")

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))
            for label in [0, 1]:
                logger.info('Class %s: ' % (label) +
                            'Conditional graph generation accuracy = {}'.
                            format(acc_count_by_label[label] / num_test_batch))
            logger.info(
                'Overall conditional graph generation accuracy = {}'.format(
                    graph_acc_count / (2 * num_test_batch)))

            graphs_gen = [get_graph(aa) for aa in A_pred]

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') +
                                    1:test_epoch.find('.pth')]
            save_name = os.path.join(
                self.config.save_dir,
                '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                    self.config.test.test_model_name[:-4], test_epoch,
                    self.block_size, self.stride))

            # remove isolated nodes for better visualization
            graphs_pred_vis = [
                copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
            ]

            if self.better_vis:
                for gg in graphs_pred_vis:
                    gg.remove_nodes_from(list(nx.isolates(gg)))

            # display the largest connected component for better visualization
            vis_graphs = []
            for gg in graphs_pred_vis:
                largest_cc = max(nx.connected_components(gg),
                                 key=len,
                                 default=None)
                # A graph emptied by isolate removal has no components; keep
                # it as-is instead of failing with an IndexError.
                if largest_cc is None:
                    vis_graphs += [gg]
                else:
                    vis_graphs += [gg.subgraph(largest_cc)]

            if self.is_single_plot:
                draw_graph_list(vis_graphs,
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(vis_graphs,
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')

            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis],
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(self.graphs_train[:self.num_vis],
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info(
                'Validity accuracy of generated graphs = {}'.format(acc))

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # Compared with Validation Set
        num_nodes_dev = [len(gg.nodes)
                         for gg in self.graphs_dev]  # shape B X 1
        mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
            self.graphs_dev, graphs_gen, degree_only=False)
        mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                        [np.bincount(num_nodes_gen)],
                                        kernel=gaussian_emd)

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes)
                          for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                         [np.bincount(num_nodes_gen)],
                                         kernel=gaussian_emd)

        logger.info(
            "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev,
                    mmd_4orbits_dev, mmd_spectral_dev))
        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test,
                    mmd_4orbits_test, mmd_spectral_test))

        if self.config.dataset.name in ['lobster']:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, acc
        else:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test