Example #1
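Test routine for a trained graph generative model: it builds an Erdős–Rényi baseline from training-set statistics, samples labeled graphs from the loaded model, checks how many are valid/bipartite, and visualizes them; an early `return` skips the MMD evaluation at the bottom.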
    def test(self):
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        # if self.config.test.is_test_ER:
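        # Empirical edge density of the training graphs, used as the
        # connection probability of the Erdos-Renyi baseline.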
        p_ER = sum([aa.number_of_edges()
                    for aa in self.graphs_train]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
        # graphs_baseline = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
        graphs_gen = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
        temp = []
        for G in graphs_gen:
            G.remove_nodes_from(list(nx.isolates(G)))
            if G.number_of_nodes() > 0:  # isolate removal may leave an empty graph
                # take the largest connected component
                CGs = [G.subgraph(c) for c in nx.connected_components(G)]
                CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
                temp.append(CGs[0])
        # graphs_gen = temp
        graphs_baseline = temp

        # else:
        ### load model
        model = eval(self.model_conf.name)(self.config)
        model_file = os.path.join(self.config.save_dir, self.test_conf.test_model_name)
        load_model(model, model_file, self.device)

        if self.use_gpu:
            model = nn.DataParallel(model, device_ids=self.gpus).to(self.device)

        model.eval()

        ### Generate Graphs
        A_pred = []
        node_label_pred = []
        num_nodes_pred = []
        num_test_batch = int(np.ceil(self.num_test_gen / self.test_conf.batch_size))

        gen_run_time = []
        for ii in tqdm(range(num_test_batch)):
            with torch.no_grad():
                start_time = time.time()
                input_dict = {}
                input_dict['is_sampling'] = True
                input_dict['batch_size'] = self.test_conf.batch_size
                input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
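                # One sampling pass returns a batch of adjacency matrices and
                # the corresponding per-node label tensors.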
                A_tmp, node_label_tmp = model(input_dict)
                gen_run_time += [time.time() - start_time]
                A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
                node_label_pred += [ll.data.cpu().numpy() for ll in node_label_tmp]
                num_nodes_pred += [aa.shape[0] for aa in A_tmp]
        # print(len(A_pred), type(A_pred[0]))

        logger.info('Average test time per mini-batch = {}'.format(np.mean(gen_run_time)))

        # print(A_pred[0].shape,
        #       get_graph(A_pred[0]).number_of_nodes(),
        #       get_graph_with_labels(A_pred[0], node_label_pred[0]).number_of_nodes())
        # print(A_pred[0])
        # return
        # graphs_gen = [get_graph(aa) for aa in A_pred]
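        # Convert each sampled adjacency matrix and its node labels into a
        # NetworkX graph.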
        graphs_gen = [get_graph_with_labels(aa, ll) for aa, ll in zip(A_pred, node_label_pred)]
        valid_pctg, bipartite_pctg = calculate_validity(graphs_gen)  # for adding bipartite graph attribute

        # return

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') + 1:test_epoch.find('.pth')]
            save_name = os.path.join(self.config.save_dir, '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                self.config.test.test_model_name[:-4], test_epoch, self.block_size, self.stride))

            # remove isolated nodes for better visualization
            # graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]]
            graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen if not gg.graph['bipartite']]
            logger.info('Number of non-bipartite graphs: {} / {}'.format(len(graphs_pred_vis), len(graphs_gen)))
            # if self.better_vis:
            #     for gg in graphs_pred_vis:
            #         gg.remove_nodes_from(list(nx.isolates(gg)))

            # # display the largest connected component for better visualization
            # vis_graphs = []
            # for gg in graphs_pred_vis:
            #     CGs = [gg.subgraph(c) for c in nx.connected_components(gg)] # nx.subgraph makes a graph frozen!
            #     CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
            #     vis_graphs += [CGs[0]]
            vis_graphs = graphs_pred_vis

            if self.is_single_plot:
                draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
            else:  #XD: using this for now
                draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')
            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis], num_row, num_col, fname=save_name, layout='spring')
                print('training single plot saved at:', save_name)
            else:  #XD: using this for now
                graph_list_train = [get_graph_from_nx(G) for G in self.graphs_train[:self.num_vis]]
                draw_graph_list_separate(graph_list_train, fname=save_name[:-4], is_single=True, layout='spring')
                print('training plots saved individually at:', save_name[:-4])
        return  # early exit: the evaluation code below never runs

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info('Validity accuracy of generated graphs = {}'.format(acc))
        '''=====XD====='''
        ## graphs_gen = [generate_random_baseline_single(len(aa)) for aa in graphs_gen]  # use this line for random baseline MMD scores. Remember to comment it later!
        # draw_hists(self.graphs_test, graphs_baseline, graphs_gen)
        valid_pctg, bipartite_pctg = calculate_validity(graphs_gen)
        # logger.info('Generated {} graphs, valid percentage = {:.2f}, bipartite percentage = {:.2f}'.format(
        #     len(graphs_gen), valid_pctg, bipartite_pctg))
        # # return
        '''=====XD====='''

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # # Compared with Validation Set
        # num_nodes_dev = [len(gg.nodes) for gg in self.graphs_dev]  # shape B X 1
        # mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_mean_degree_dev, mmd_max_degree_dev, mmd_mean_centrality_dev, mmd_assortativity_dev, mmd_mean_degree_connectivity_dev = evaluate(self.graphs_dev, graphs_gen, degree_only=False)
        # mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)
        # logger.info("Validation MMD scores of #nodes/degree/clustering/4orbits/spectral/... are = {:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}".format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_mean_degree_dev, mmd_max_degree_dev, mmd_mean_centrality_dev, mmd_assortativity_dev, mmd_mean_degree_connectivity_dev))

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes) for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, mmd_mean_degree_test, mmd_max_degree_test, mmd_mean_centrality_test, mmd_assortativity_test, mmd_mean_degree_connectivity_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd(
            [np.bincount(num_nodes_test)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral/... are = {:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}/{:.5f}".
            format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test,
                   mmd_mean_degree_test, mmd_max_degree_test, mmd_mean_centrality_test, mmd_assortativity_test,
                   mmd_mean_degree_connectivity_test))
Example #2
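Variant of the test routine: it either generates an Erdős–Rényi baseline or samples from the loaded model (which also returns per-graph `alpha` values), pickles the graphs selected for visualization, and reports MMD scores against both the validation and test sets.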
    def test(self):
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        if self.config.test.is_test_ER:
            p_ER = sum([
                aa.number_of_edges() for aa in self.graphs_train
            ]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
            graphs_gen = [
                nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii)
                for ii in range(self.num_test_gen)
            ]
        else:
            ### load model
            model = eval(self.model_conf.name)(self.config)
            model_file = os.path.join(self.config.save_dir,
                                      self.test_conf.test_model_name)
            load_model(model, model_file, self.device)

            if self.use_gpu:
                model = nn.DataParallel(model,
                                        device_ids=self.gpus).to(self.device)

            model.eval()

            ### Generate Graphs
            A_pred = []
            num_nodes_pred = []
            alpha_list = []
            num_test_batch = int(
                np.ceil(self.num_test_gen / self.test_conf.batch_size))

            gen_run_time = []
            for ii in tqdm(range(num_test_batch)):
                with torch.no_grad():
                    start_time = time.time()
                    input_dict = {}
                    input_dict['is_sampling'] = True
                    input_dict['batch_size'] = self.test_conf.batch_size
                    input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
                    A_tmp, alpha_temp = model(input_dict)
                    gen_run_time += [time.time() - start_time]
                    A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
                    num_nodes_pred += [aa.shape[0] for aa in A_tmp]
                    alpha_list += [aa.data.cpu().numpy() for aa in alpha_temp]

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))

            graphs_gen = [
                get_graph(aa, alpha_list[i]) for i, aa in enumerate(A_pred)
            ]

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') +
                                    1:test_epoch.find('.pth')]
            save_name = os.path.join(
                self.config.save_dir,
                '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                    self.config.test.test_model_name[:-4], test_epoch,
                    self.block_size, self.stride))

            # remove isolated nodes for better visualization
            graphs_pred_vis = [
                copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
            ]

            # Saves Graphs
            for i, gg in enumerate(graphs_pred_vis):
                G = gg
                name = os.path.join(
                    self.config.save_dir,
                    '{}_gen_graphs_epoch_{}_{}.pickle'.format(
                        self.config.test.test_model_name[:-4], test_epoch, i))
                with open(name, 'wb') as handle:
                    pickle.dump(G, handle)

            if self.better_vis:
                for gg in graphs_pred_vis:
                    gg.remove_nodes_from(list(nx.isolates(gg)))

            # display the largest connected component for better visualization
            vis_graphs = []
            for gg in graphs_pred_vis:
                CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
                CGs = sorted(CGs,
                             key=lambda x: x.number_of_nodes(),
                             reverse=True)
                vis_graphs += [CGs[0]]

            if self.is_single_plot:
                draw_graph_list(vis_graphs,
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(vis_graphs,
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')

            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis],
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(self.graphs_train[:self.num_vis],
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info(
                'Validity accuracy of generated graphs = {}'.format(acc))

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # Compared with Validation Set
        num_nodes_dev = [len(gg.nodes)
                         for gg in self.graphs_dev]  # shape B X 1
        mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
            self.graphs_dev, graphs_gen, degree_only=False)
        mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                        [np.bincount(num_nodes_gen)],
                                        kernel=gaussian_emd)

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes)
                          for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                         [np.bincount(num_nodes_gen)],
                                         kernel=gaussian_emd)

        logger.info(
            "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev,
                    mmd_4orbits_dev, mmd_spectral_dev))
        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test,
                    mmd_4orbits_test, mmd_spectral_test))

        if self.config.dataset.name in ['lobster']:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, acc
        else:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test
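Example #3
A test variant that can re-score each sampled graph's node count with an auxiliary "complete graph" classifier and, when test.hard_multi is set, sweeps several hard binarization thresholds, logging MMD scores per pass.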
  def test(self):
    self.config.save_dir_train = self.test_conf.test_model_dir

    if not self.config.test.is_test_ER:
      ### load model
      model = eval(self.model_conf.name)(self.config)
      model_file = os.path.join(self.config.save_dir_train, self.test_conf.test_model_name)
      load_model(model, model_file, self.device)
      if self.use_gpu:
        model = nn.DataParallel(model, device_ids=self.gpus).to(self.device)
      model.eval()

      if hasattr(self.config, 'complete_graph_model'):
        complete_graph_model = eval(self.config.complete_graph_model.name)(self.config.complete_graph_model)
        complete_graph_model_file = os.path.join(self.config.complete_graph_model.test_model_dir,
                                                 self.config.complete_graph_model.test_model_name)
        load_model(complete_graph_model, complete_graph_model_file, self.device)
        if self.use_gpu:
          complete_graph_model = nn.DataParallel(complete_graph_model, device_ids=self.gpus).to(self.device)
        complete_graph_model.eval()

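    # With test.hard_multi enabled, sweep deterministic binarization
    # thresholds 0.5-0.9; otherwise a single pass uses the model's own
    # stochastic sampling (hard_thre = None).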
    if self.config.test.is_test_ER or not hasattr(self.config.test, 'hard_multi') or not self.config.test.hard_multi:
      hard_thre_list = [None]
    else:
      hard_thre_list = np.arange(0.5, 1, 0.1)

    for test_hard_idx, hard_thre in enumerate(hard_thre_list):
      if self.config.test.is_test_ER:
        ### Compute Erdos-Renyi baseline
        p_ER = sum([aa.number_of_edges() for aa in self.graphs_train]) / sum([aa.number_of_nodes() ** 2 for aa in self.graphs_train])
        graphs_gen = [nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii) for ii in range(self.num_test_gen)]
      else:
        logger.info('Test pass {}. Hard threshold {}'.format(test_hard_idx, hard_thre))
        ### Generate Graphs
        A_pred = []
        num_nodes_pred = []
        num_test_batch = int(np.ceil(self.num_test_gen / self.test_conf.batch_size))

        gen_run_time = []
        for ii in tqdm(range(num_test_batch)):
          with torch.no_grad():
            start_time = time.time()
            input_dict = {}
            input_dict['is_sampling'] = True
            input_dict['batch_size'] = self.test_conf.batch_size
            input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
            input_dict['hard_thre'] = hard_thre
            A_tmp = model(input_dict)

            if hasattr(self.config, 'complete_graph_model'):
              final_A_list = []
              for batch_idx in range(len(A_tmp)):
                new_pmf = torch.zeros(len(self.num_nodes_pmf_train))
                max_prob = 0.
                max_prob_num_nodes = None
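                # Re-score every plausible node count with the auxiliary
                # classifier; counts scoring <= 0.9 are zeroed out, and the
                # argmax is kept as a fallback in case all of them are.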
                for num_nodes, prob in enumerate(self.num_nodes_pmf_train):
                  if prob == 0.:
                    continue
                  tmp_data = {}
                  A_tmp_tmp = A_tmp[batch_idx][:num_nodes, :num_nodes]
                  tmp_data['adj'] = F.pad(
                    A_tmp_tmp,
                    (0, self.config.complete_graph_model.model.max_num_nodes-num_nodes, 0, 0),
                    'constant', value=.0)[None, None, ...]

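                  # Symmetrize the lower-triangular sample and extract its
                  # edge list in COO format for the classifier input.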
                  adj = torch.tril(A_tmp_tmp, diagonal=-1)
                  adj = adj + adj.transpose(0, 1)
                  edges = adj.to_sparse().coalesce().indices()
                  tmp_data['edges'] = edges.t()
                  tmp_data['subgraph_idx'] = torch.zeros(num_nodes).long().to(self.device, non_blocking=True)

                  tmp_logit = complete_graph_model(tmp_data)
                  new_pmf[num_nodes] = torch.sigmoid(tmp_logit).item()

                  if new_pmf[num_nodes] > max_prob:
                    max_prob = new_pmf[num_nodes]
                    max_prob_num_nodes = num_nodes

                  if new_pmf[num_nodes] <= 0.9:
                    new_pmf[num_nodes] = 0.

                if (new_pmf == 0.).all():
                  logger.info('(new_pmf == 0.).all(), use {} nodes with max prob {}'.format(max_prob_num_nodes, max_prob))
                  final_num_nodes = max_prob_num_nodes
                else:
                  final_num_nodes = torch.multinomial(new_pmf, 1).item()
                final_A_list.append(
                  # A_tmp_tmp still holds the largest candidate slice from the
                  # loop above, so it always covers final_num_nodes rows/columns.
                  A_tmp_tmp[:final_num_nodes, :final_num_nodes]
                )
              A_tmp = final_A_list
            gen_run_time += [time.time() - start_time]
            A_pred += [aa.cpu().numpy() for aa in A_tmp]
            num_nodes_pred += [aa.shape[0] for aa in A_tmp]
        print('num_nodes_pred', num_nodes_pred)
        logger.info('Average test time per mini-batch = {}'.format(
          np.mean(gen_run_time)))

        graphs_gen = [get_graph(aa) for aa in A_pred]

      ### Visualize Generated Graphs
      if self.is_vis:
        num_col = self.vis_num_row
        num_row = self.num_vis // num_col
        test_epoch = self.test_conf.test_model_name
        test_epoch = test_epoch[test_epoch.rfind('_') + 1:test_epoch.find('.pth')]
        if hard_thre is not None:
          save_name = os.path.join(self.config.save_dir_train, '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
            self.config.test.test_model_name[:-4], test_epoch,
            int(round(hard_thre*10))))
          save_name2 = os.path.join(self.config.save_dir,
                                   '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                                     self.config.test.test_model_name[:-4], test_epoch,
                                     int(round(hard_thre * 10))))
        else:
          save_name = os.path.join(self.config.save_dir_train,
                                   '{}_gen_graphs_epoch_{}.png'.format(
                                     self.config.test.test_model_name[:-4], test_epoch))
          save_name2 = os.path.join(self.config.save_dir,
                                    '{}_gen_graphs_epoch_{}.png'.format(
                                      self.config.test.test_model_name[:-4], test_epoch))

        # remove isolated nodes for better visualization
        graphs_pred_vis = [copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]]

        if self.better_vis:
          # actually not necessary with the following largest connected component selection
          for gg in graphs_pred_vis:
            gg.remove_nodes_from(list(nx.isolates(gg)))

        # display the largest connected component for better visualization
        vis_graphs = []
        for gg in graphs_pred_vis:
          if self.better_vis:
            CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
            CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
            vis_graphs += [CGs[0]]
          else:
            vis_graphs += [gg]
        print('number of nodes after better vis', [tmp_g.number_of_nodes() for tmp_g in vis_graphs])

        if self.is_single_plot:
          # draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
          draw_graph_list(vis_graphs, num_row, num_col, fname=save_name2, layout='spring')
        else:
          # draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')
          draw_graph_list_separate(vis_graphs, fname=save_name2[:-4], is_single=True, layout='spring')

        if test_hard_idx == 0:
          save_name = os.path.join(self.config.save_dir_train, 'train_graphs.png')

          if self.is_single_plot:
            draw_graph_list(
              self.graphs_train[:self.num_vis],
              num_row,
              num_col,
              fname=save_name,
              layout='spring')
          else:
            draw_graph_list_separate(
              self.graphs_train[:self.num_vis],
              fname=save_name[:-4],
              is_single=True,
              layout='spring')

      ### Evaluation
      if self.config.dataset.name in ['lobster']:
        acc = eval_acc_lobster_graph(graphs_gen)
        logger.info('Validity accuracy of generated graphs = {}'.format(acc))

      num_nodes_gen = [len(aa) for aa in graphs_gen]

      # Compared with Validation Set
      num_nodes_dev = [len(gg.nodes) for gg in self.graphs_dev]  # shape B X 1
      mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(self.graphs_dev, graphs_gen,
                                                                                       degree_only=False)
      mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

      # Compared with Test Set
      num_nodes_test = [len(gg.nodes) for gg in self.graphs_test]  # shape B X 1
      mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(self.graphs_test, graphs_gen,
                                                                                           degree_only=False)
      mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)], [np.bincount(num_nodes_gen)], kernel=gaussian_emd)

      logger.info(
        "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}".format(
          Decimal(mmd_num_nodes_dev), Decimal(mmd_degree_dev),
          Decimal(mmd_clustering_dev), Decimal(mmd_4orbits_dev),
          Decimal(mmd_spectral_dev)))
      logger.info(
        "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}".format(
          Decimal(mmd_num_nodes_test), Decimal(mmd_degree_test),
          Decimal(mmd_clustering_test), Decimal(mmd_4orbits_test),
          Decimal(mmd_spectral_test)))
Example #4
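Test routine for a conditional generator built with `define_G`: adjacency matrices are produced from one-hot (m, n) label vectors (optionally padded with noise), symmetrized, binarized by a hard threshold or Bernoulli sampling, and evaluated with MMD.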
    def test(self):
        self.config.save_dir_train = self.test_conf.test_model_dir

        ### test dataset
        test_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                           self.graphs_test,
                                                           tag='test')
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.test_conf.batch_size,
            shuffle=False,
            num_workers=self.train_conf.num_workers,
            collate_fn=test_dataset.collate_fn,
            drop_last=False)

        ### load model
        args = self.config.model
        n_labels = self.dataset_conf.max_m + self.dataset_conf.max_n
        G = define_G(args.nz, args.ngf, args.netG, args.final_activation,
                     args.norm_G)
        model_file_G = os.path.join(self.config.save_dir_train,
                                    self.test_conf.test_model_name)

        load_model(G, model_file_G, self.device)
        if self.use_gpu:
            G = G.cuda()  #nn.DataParallel(G).to(self.device)
        G.train()  # kept in train mode at test time (presumably for batch-norm behavior)

        if not hasattr(self.config.test,
                       'hard_multi') or not self.config.test.hard_multi:
            hard_thre_list = [None]
        else:
            hard_thre_list = np.arange(0.5, 1, 0.1)

        for test_hard_idx, hard_thre in enumerate(hard_thre_list):
            logger.info('Test pass {}. Hard threshold {}'.format(
                test_hard_idx, hard_thre))
            ### Generate Graphs
            A_pred = []
            gen_run_time = []

            for batch_data in test_loader:
                # asserted in arg helper
                ff = 0

                with torch.no_grad():
                    data = {}
                    data['adj'] = batch_data[ff]['adj'].pin_memory().to(
                        self.config.device, non_blocking=True)
                    data['m'] = batch_data[ff]['m'].to(self.config.device,
                                                       non_blocking=True)
                    data['n'] = batch_data[ff]['n'].to(self.config.device,
                                                       non_blocking=True)

                    batch_size = data['adj'].size(0)

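                    # Conditioning input: one-hot encodings of the (m, n)
                    # labels; if the generator expects more channels (nz),
                    # the remainder is filled with Gaussian noise.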
                    i_onehot = torch.zeros(
                        (batch_size, self.dataset_conf.max_m),
                        requires_grad=True).pin_memory().to(self.config.device,
                                                            non_blocking=True)
                    i_onehot.scatter_(1, data['m'][:, None] - 1, 1)
                    j_onehot = torch.zeros(
                        (batch_size, self.dataset_conf.max_n),
                        requires_grad=True).pin_memory().to(self.config.device,
                                                            non_blocking=True)
                    j_onehot.scatter_(1, data['n'][:, None] - 1, 1)
                    y_onehot = torch.cat((i_onehot, j_onehot), dim=1)

                    if args.nz > n_labels:
                        noise = torch.randn(
                            (batch_size, args.nz - n_labels, 1, 1),
                            requires_grad=True).to(self.config.device,
                                                   non_blocking=True)
                        z_input = torch.cat(
                            (y_onehot.view(batch_size, n_labels, 1, 1), noise),
                            dim=1)
                    else:
                        z_input = y_onehot.view(batch_size, n_labels, 1, 1)

                    start_time = time.time()
                    output = G(z_input).squeeze(1)  # (B, 1, n, n) -> (B, n, n)
                    if self.model_conf.final_activation == 'tanh':
                        output = (output + 1) / 2
                    if self.model_conf.is_sym:
                        output = torch.tril(output, diagonal=-1)
                        output = output + output.transpose(1, 2)
                    gen_run_time += [time.time() - start_time]

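                    # Binarize the soft adjacency: deterministic thresholding
                    # in the hard-threshold passes, Bernoulli sampling otherwise.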
                    if hard_thre is not None:
                        A_pred += [(output[batch_idx, ...] >
                                    hard_thre).long().cpu().numpy()
                                   for batch_idx in range(batch_size)]
                    else:
                        A_pred += [
                            torch.bernoulli(output[batch_idx,
                                                   ...]).long().cpu().numpy()
                            for batch_idx in range(batch_size)
                        ]

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))

            graphs_gen = [get_graph(aa) for aa in A_pred]

            ### Visualize Generated Graphs
            if self.is_vis:
                num_col = self.vis_num_row
                num_row = self.num_vis // num_col
                test_epoch = self.test_conf.test_model_name
                test_epoch = test_epoch[test_epoch.rfind('_') +
                                        1:test_epoch.find('.pth')]
                if hard_thre is not None:
                    save_name = os.path.join(
                        self.config.save_dir_train,
                        '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch,
                            int(round(hard_thre * 10))))
                    save_name2 = os.path.join(
                        self.config.save_dir,
                        '{}_gen_graphs_epoch_{}_hard_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch,
                            int(round(hard_thre * 10))))
                else:
                    save_name = os.path.join(
                        self.config.save_dir_train,
                        '{}_gen_graphs_epoch_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch))
                    save_name2 = os.path.join(
                        self.config.save_dir,
                        '{}_gen_graphs_epoch_{}.png'.format(
                            self.config.test.test_model_name[:-4], test_epoch))

                # remove isolated nodes for better visualization
                graphs_pred_vis = [
                    copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
                ]

                if self.better_vis:
                    # actually not necessary with the following largest connected component selection
                    for gg in graphs_pred_vis:
                        gg.remove_nodes_from(list(nx.isolates(gg)))

                # display the largest connected component for better visualization
                vis_graphs = []
                for gg in graphs_pred_vis:
                    if self.better_vis:
                        CGs = [
                            gg.subgraph(c) for c in nx.connected_components(gg)
                        ]
                        CGs = sorted(CGs,
                                     key=lambda x: x.number_of_nodes(),
                                     reverse=True)
                        vis_graphs += [CGs[0]]
                    else:
                        vis_graphs += [gg]
                print('number of nodes after better vis',
                      [tmp_g.number_of_nodes() for tmp_g in vis_graphs])

                if self.is_single_plot:
                    # draw_graph_list(vis_graphs, num_row, num_col, fname=save_name, layout='spring')
                    draw_graph_list(vis_graphs,
                                    num_row,
                                    num_col,
                                    fname=save_name2,
                                    layout='spring')
                else:
                    # draw_graph_list_separate(vis_graphs, fname=save_name[:-4], is_single=True, layout='spring')
                    draw_graph_list_separate(vis_graphs,
                                             fname=save_name2[:-4],
                                             is_single=True,
                                             layout='spring')

                if test_hard_idx == 0:
                    save_name = os.path.join(self.config.save_dir_train,
                                             'train_graphs.png')

                    if self.is_single_plot:
                        draw_graph_list(self.graphs_train[:self.num_vis],
                                        num_row,
                                        num_col,
                                        fname=save_name,
                                        layout='spring')
                    else:
                        draw_graph_list_separate(
                            self.graphs_train[:self.num_vis],
                            fname=save_name[:-4],
                            is_single=True,
                            layout='spring')

            ### Evaluation
            if self.config.dataset.name in ['lobster']:
                acc = eval_acc_lobster_graph(graphs_gen)
                logger.info(
                    'Validity accuracy of generated graphs = {}'.format(acc))

            num_nodes_gen = [len(aa) for aa in graphs_gen]

            # Compared with Validation Set
            num_nodes_dev = [len(gg.nodes)
                             for gg in self.graphs_dev]  # shape B X 1
            mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
                self.graphs_dev, graphs_gen, degree_only=False)
            mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                            [np.bincount(num_nodes_gen)],
                                            kernel=gaussian_emd)

            # Compared with Test Set
            num_nodes_test = [len(gg.nodes)
                              for gg in self.graphs_test]  # shape B X 1
            mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
                self.graphs_test, graphs_gen, degree_only=False)
            mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                             [np.bincount(num_nodes_gen)],
                                             kernel=gaussian_emd)

            logger.info(
                "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}"
                .format(Decimal(mmd_num_nodes_dev), Decimal(mmd_degree_dev),
                        Decimal(mmd_clustering_dev), Decimal(mmd_4orbits_dev),
                        Decimal(mmd_spectral_dev)))
            logger.info(
                "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {:.4E}/{:.4E}/{:.4E}/{:.4E}/{:.4E}"
                .format(Decimal(mmd_num_nodes_test), Decimal(mmd_degree_test),
                        Decimal(mmd_clustering_test),
                        Decimal(mmd_4orbits_test), Decimal(mmd_spectral_test)))
Example #5
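Training loop: multi-GPU batching with `DataParallel`, SGD/Adam with a MultiStep learning-rate schedule, optional resume from a snapshot, periodic model snapshots, and periodic qualitative sampling of generated graphs.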
    def train(self):
        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        model = eval(self.model_conf.name)(self.config)
        print('number of parameters : {}'.format(
            sum([np.prod(x.shape) for x in model.parameters()])))

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)

        # Attach the scheduler to the live optimizer; scheduling a deepcopy
        # would leave the real learning rate untouched.
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            has_sampled = False
            model.train()
            # lr_scheduler.step()
            train_iterator = iter(train_loader)

            for inner_iter in range(len(train_loader) // self.num_gpus):
                optimizer.zero_grad()

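                # Fetch one mini-batch per GPU; the forward pass below feeds
                # one batch to each device.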
                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data.append(data)
                        iter_count += 1

                avg_train_loss = .0
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = {}
                            data['adj'] = batch_data[dd][ff]['adj'].pin_memory(
                            ).to(gpu_id, non_blocking=True)
                            data['edges'] = batch_data[dd][ff][
                                'edges'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['node_idx_gnn'] = batch_data[dd][ff][
                                'node_idx_gnn'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['node_idx_feat'] = batch_data[dd][ff][
                                'node_idx_feat'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            data['label'] = batch_data[dd][ff][
                                'label'].pin_memory().to(gpu_id,
                                                         non_blocking=True)
                            data['att_idx'] = batch_data[dd][ff][
                                'att_idx'].pin_memory().to(gpu_id,
                                                           non_blocking=True)
                            data['subgraph_idx'] = batch_data[dd][ff][
                                'subgraph_idx'].pin_memory().to(
                                    gpu_id, non_blocking=True)
                            batch_fwd.append((data, ))

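                    # Each entry of batch_fwd is a one-element tuple holding
                    # one GPU's inputs for model(*batch_fwd).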
                    if batch_fwd:
                        train_loss = model(*batch_fwd).mean()
                        avg_train_loss += train_loss

                        # assign gradient
                        train_loss.backward()

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                avg_train_loss /= float(self.dataset_conf.num_fwd_pass)

                # reduce
                train_loss = float(avg_train_loss.data.cpu().numpy())

                self.writer.add_scalar('train_loss', train_loss, iter_count)
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "NLL Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count, train_loss))

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

            if (epoch + 1) % 20 == 0 and not has_sampled:
                has_sampled = True
                print('saving graphs')
                model.eval()
                graphs_gen = [
                    get_graph(aa.cpu().data.numpy())
                    for aa in model.module._sampling(10)
                ]
                model.train()

                vis_graphs = []
                for gg in graphs_gen:
                    CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
                    CGs = sorted(CGs,
                                 key=lambda x: x.number_of_nodes(),
                                 reverse=True)
                    vis_graphs += [CGs[0]]

                total = len(vis_graphs)  #min(3, len(vis_graphs))
                draw_graph_list(vis_graphs[:total],
                                2,
                                int(total // 2),
                                fname='sample/gran_%d.png' % epoch,
                                layout='spring')

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()

        return 1
Example #6
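Conditional-generation test: class-conditioned graphs are sampled and scored by a pretrained GraphSAGE classifier (PROTEINS), with per-class accuracy logged alongside the usual visualization and MMD evaluation.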
    def test(self):
        self.config.save_dir = self.test_conf.test_model_dir

        ### Compute Erdos-Renyi baseline
        if self.config.test.is_test_ER:
            p_ER = sum([
                aa.number_of_edges() for aa in self.graphs_train
            ]) / sum([aa.number_of_nodes()**2 for aa in self.graphs_train])
            graphs_gen = [
                nx.fast_gnp_random_graph(self.max_num_nodes, p_ER, seed=ii)
                for ii in range(self.num_test_gen)
            ]
        else:
            ### load model
            model = eval(self.model_conf.name)(self.config)
            model_file = os.path.join(self.config.save_dir,
                                      self.test_conf.test_model_name)
            load_model(model, model_file, self.device)

            # create graph classifier
            graph_classifier = GraphSAGE(3, 2, 3, 32, 'add')
            # graph_classifier = DiffPool(3, 2, max_num_nodes=630)
            # graph_classifier = DGCNN(3, 2, 'PROTEINS_full')
            graph_classifier.load_state_dict(
                torch.load('output/MODEL_PROTEINS.pkl'))
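            # The pretrained classifier is used below to measure how often
            # conditionally generated graphs are recognized as their target
            # class.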

            if self.use_gpu:
                model = nn.DataParallel(model,
                                        device_ids=self.gpus).to(self.device)
                graph_classifier = graph_classifier.to(self.device)

            model.eval()
            graph_classifier.eval()

            ### Generate Graphs
            A_pred = []
            num_nodes_pred = []
            num_test_batch = 5000

            gen_run_time = []
            graph_acc_count = 0
            # for ii in tqdm(range(1)):
            #     with torch.no_grad():
            #         start_time = time.time()
            #         input_dict = {}
            #         input_dict['is_sampling'] = True
            #         input_dict['batch_size'] = self.test_conf.batch_size
            #         input_dict['num_nodes_pmf'] = self.num_nodes_pmf_train
            #         A_tmp, label_tmp = model(input_dict)
            #         gen_run_time += [time.time() - start_time]
            #         A_pred += [aa.data.cpu().numpy() for aa in A_tmp]
            #         num_nodes_pred += [aa.shape[0] for aa in A_tmp]

            from classifier.losses import MulticlassClassificationLoss
            classifier_loss = MulticlassClassificationLoss()

            ps = []

            acc_count_by_label = {0: 0, 1: 0}

            graph_acc_count = 0
            for ii in tqdm(range(2 * num_test_batch)):
                with torch.no_grad():
                    label = ii % 2
                    graph_label = torch.tensor([label]).to(self.device).long()
                    start_time = time.time()
                    input_dict = {}
                    input_dict['is_sampling'] = True
                    input_dict['batch_size'] = self.test_conf.batch_size
                    input_dict['num_nodes_pmf'] = self.num_nodes_pmf_by_group[
                        graph_label.item()]
                    input_dict['graph_label'] = graph_label

                    A_tmp, label_tmp = model(input_dict)
                    A_tmp = A_tmp[0]
                    label_tmp = label_tmp[0]

                    label_tmp = label_tmp.long()

                    lower_part = torch.tril(A_tmp, diagonal=-1)

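                    # One-hot node features over the 3 label classes, as
                    # expected by the classifier.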
                    x = torch.zeros((A_tmp.shape[0], 3)).to(self.device)
                    x[list(range(A_tmp.shape[0])), label_tmp] = 1

                    edge_mask = (lower_part != 0).to(self.device)
                    edges = edge_mask.nonzero().transpose(0, 1).to(self.device)
                    edges_other_way = edges[[1, 0]]
                    edges = torch.cat([edges, edges_other_way],
                                      dim=-1).to(self.device)

                    batch = torch.zeros(A_tmp.shape[0]).long().to(self.device)

                    data = Bunch(x=x,
                                 edge_index=edges,
                                 batch=batch,
                                 y=graph_label,
                                 edge_weight=None)

                    n_nodes = batch.shape[0]
                    n_edges = edges.shape[1]

                    output = graph_classifier(data)

                    if not isinstance(output, tuple):
                        output = (output, )

                    graph_classification_loss, graph_classification_acc = classifier_loss(
                        data.y, *output)
                    graph_acc_count += graph_classification_acc / 100

                    acc_count_by_label[label] += graph_classification_acc / 100

                    print(graph_classification_acc, graph_label)

                    if ii % 100 == 99:
                        n_graphs_each = (ii + 1) / 2
                        print("\033[92m" +
                              "Class 0: %.3f ----  Class 1: %.3f" %
                              (acc_count_by_label[0] / n_graphs_each,
                               acc_count_by_label[1] / n_graphs_each) +
                              "\033[0m")

            logger.info('Average test time per mini-batch = {}'.format(
                np.mean(gen_run_time)))
            for label in [0, 1]:
                graph_acc_count = acc_count_by_label[label]
                logger.info('Class %s: ' % (label) +
                            'Conditional graph generation accuracy = {}'.
                            format(graph_acc_count / num_test_batch))

            graphs_gen = [get_graph(aa) for aa in A_pred]  # NOTE: A_pred is never filled above, so this list stays empty

        ### Visualize Generated Graphs
        if self.is_vis:
            num_col = self.vis_num_row
            num_row = int(np.ceil(self.num_vis / num_col))
            test_epoch = self.test_conf.test_model_name
            test_epoch = test_epoch[test_epoch.rfind('_') +
                                    1:test_epoch.find('.pth')]
            save_name = os.path.join(
                self.config.save_dir,
                '{}_gen_graphs_epoch_{}_block_{}_stride_{}.png'.format(
                    self.config.test.test_model_name[:-4], test_epoch,
                    self.block_size, self.stride))

            # remove isolated nodes for better visualization
            graphs_pred_vis = [
                copy.deepcopy(gg) for gg in graphs_gen[:self.num_vis]
            ]

            if self.better_vis:
                for gg in graphs_pred_vis:
                    gg.remove_nodes_from(list(nx.isolates(gg)))

            # display the largest connected component for better visualization
            vis_graphs = []
            for gg in graphs_pred_vis:
                CGs = [gg.subgraph(c) for c in nx.connected_components(gg)]
                CGs = sorted(CGs,
                             key=lambda x: x.number_of_nodes(),
                             reverse=True)
                vis_graphs += [CGs[0]]

            if self.is_single_plot:
                draw_graph_list(vis_graphs,
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(vis_graphs,
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

            save_name = os.path.join(self.config.save_dir, 'train_graphs.png')

            if self.is_single_plot:
                draw_graph_list(self.graphs_train[:self.num_vis],
                                num_row,
                                num_col,
                                fname=save_name,
                                layout='spring')
            else:
                draw_graph_list_separate(self.graphs_train[:self.num_vis],
                                         fname=save_name[:-4],
                                         is_single=True,
                                         layout='spring')

        ### Evaluation
        if self.config.dataset.name in ['lobster']:
            acc = eval_acc_lobster_graph(graphs_gen)
            logger.info(
                'Validity accuracy of generated graphs = {}'.format(acc))

        num_nodes_gen = [len(aa) for aa in graphs_gen]

        # Compared with Validation Set
        num_nodes_dev = [len(gg.nodes)
                         for gg in self.graphs_dev]  # shape B X 1
        mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(
            self.graphs_dev, graphs_gen, degree_only=False)
        mmd_num_nodes_dev = compute_mmd([np.bincount(num_nodes_dev)],
                                        [np.bincount(num_nodes_gen)],
                                        kernel=gaussian_emd)

        # Compared with Test Set
        num_nodes_test = [len(gg.nodes)
                          for gg in self.graphs_test]  # shape B X 1
        mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = evaluate(
            self.graphs_test, graphs_gen, degree_only=False)
        mmd_num_nodes_test = compute_mmd([np.bincount(num_nodes_test)],
                                         [np.bincount(num_nodes_gen)],
                                         kernel=gaussian_emd)

        logger.info(
            "Validation MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_dev, mmd_degree_dev, mmd_clustering_dev,
                    mmd_4orbits_dev, mmd_spectral_dev))
        logger.info(
            "Test MMD scores of #nodes/degree/clustering/4orbits/spectral are = {}/{}/{}/{}/{}"
            .format(mmd_num_nodes_test, mmd_degree_test, mmd_clustering_test,
                    mmd_4orbits_test, mmd_spectral_test))

        if self.config.dataset.name in ['lobster']:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test, acc
        else:
            return mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev, mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test
Example #7
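Training loop for a transformer-based generator (`make_model`) trained with a mixture-of-Bernoulli loss; every 50 epochs it draws samples for visual inspection, and it snapshots the model periodically.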
    def train(self):
        ### create data loader
        train_dataset = eval(self.dataset_conf.loader_name)(self.config,
                                                            self.graphs_train,
                                                            tag='train')
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.train_conf.batch_size,
            shuffle=self.train_conf.shuffle,
            num_workers=self.train_conf.num_workers,
            collate_fn=train_dataset.collate_fn,
            drop_last=False)

        # create models
        # model = eval(self.model_conf.name)(self.config)
        from model.transformer import make_model
        model = make_model(max_node=self.config.model.max_num_nodes,
                           d_out=20,
                           N=7,
                           d_model=64,
                           d_ff=64,
                           dropout=0.4)  # d_out, N, d_model, d_ff, h
        # d_out=20, N=15, d_model=16, d_ff=16, dropout=0.2) # d_out, N, d_model, d_ff, h
        # d_out=20, N=3, d_model=64, d_ff=64, dropout=0.1) # d_out, N, d_model, d_ff, h

        if self.use_gpu:
            model = DataParallel(model, device_ids=self.gpus).to(self.device)

        # create optimizer
        params = filter(lambda p: p.requires_grad, model.parameters())
        if self.train_conf.optimizer == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=self.train_conf.lr,
                                  momentum=self.train_conf.momentum,
                                  weight_decay=self.train_conf.wd)
        elif self.train_conf.optimizer == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=self.train_conf.lr,
                                   weight_decay=self.train_conf.wd)
        else:
            raise ValueError("Non-supported optimizer!")

        early_stop = EarlyStopper([0.0], win_size=100, is_decrease=False)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_conf.lr_decay_epoch,
            gamma=self.train_conf.lr_decay)

        # reset gradient
        optimizer.zero_grad()

        # resume training
        resume_epoch = 0
        if self.train_conf.is_resume:
            model_file = os.path.join(self.train_conf.resume_dir,
                                      self.train_conf.resume_model)
            load_model(model.module if self.use_gpu else model,
                       model_file,
                       self.device,
                       optimizer=optimizer,
                       scheduler=lr_scheduler)
            resume_epoch = self.train_conf.resume_epoch

        # Training Loop
        iter_count = 0
        results = defaultdict(list)
        for epoch in range(resume_epoch, self.train_conf.max_epoch):
            model.train()
            lr_scheduler.step()  # pre-1.1 PyTorch convention: scheduler stepped at the start of each epoch
            train_iterator = iter(train_loader)

            for inner_iter in range(len(train_loader) // self.num_gpus):
                optimizer.zero_grad()

                batch_data = []
                if self.use_gpu:
                    for _ in self.gpus:
                        data = next(train_iterator)
                        batch_data += [data]

                avg_train_loss = .0
                for ff in range(self.dataset_conf.num_fwd_pass):
                    batch_fwd = []

                    if self.use_gpu:
                        for dd, gpu_id in enumerate(self.gpus):
                            data = batch_data[dd]

                            adj, lens = data['adj'], data['lens']

                            # this is only for grid
                            # adj = adj[:, :, :100, :100]
                            # lens = [min(99, x) for x in lens]

                            adj = adj.to('cuda:%d' % gpu_id)

                            # build masks
                            node_feat, attn_mask, lens = preprocess(adj, lens)
                            batch_fwd.append(
                                (node_feat, attn_mask.clone(), lens))

                    if batch_fwd:
                        node_feat, attn_mask, lens = batch_fwd[0]
                        log_theta, log_alpha = model(*batch_fwd)

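                        # Mixture-of-Bernoulli NLL over the adjacency entries,
                        # masked by the per-graph lengths.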
                        train_loss = model.module.mix_bern_loss(
                            log_theta, log_alpha, adj, lens)

                        avg_train_loss += train_loss

                        # assign gradient
                        train_loss.backward()

                # clip_grad_norm_(model.parameters(), 5.0e-0)
                optimizer.step()
                avg_train_loss /= float(self.dataset_conf.num_fwd_pass)

                # reduce
                train_loss = float(avg_train_loss.data.cpu().numpy())

                self.writer.add_scalar('train_loss', train_loss, iter_count)
                results['train_loss'] += [train_loss]
                results['train_step'] += [iter_count]

                if iter_count % self.train_conf.display_iter == 0 or iter_count == 1:
                    logger.info(
                        "NLL Loss @ epoch {:04d} iteration {:08d} = {}".format(
                            epoch + 1, iter_count, train_loss))

                if epoch % 50 == 0 and inner_iter == 0:
                    model.eval()
                    print('saving graphs')
                    graphs_gen = [get_graph(adj[0].cpu().data.numpy())] + [
                        get_graph(aa.cpu().data.numpy())
                        for aa in model.module.sample(
                            19, max_node=self.config.model.max_num_nodes)
                    ]
                    model.train()

                    vis_graphs = []
                    for gg in graphs_gen:
                        CGs = [
                            gg.subgraph(c) for c in nx.connected_components(gg)
                        ]
                        CGs = sorted(CGs,
                                     key=lambda x: x.number_of_nodes(),
                                     reverse=True)
                        if CGs:  # the sampled graph may have no nodes
                            vis_graphs += [CGs[0]]

                    try:
                        total = len(vis_graphs)  #min(3, len(vis_graphs))
                        draw_graph_list(vis_graphs[:total],
                                        4,
                                        int(total // 4),
                                        fname='sample/trans_sl:%d_%d.png' %
                                        (int(model.module.self_loop), epoch),
                                        layout='spring')
                    except Exception:
                        print('sample saving failed')

            # snapshot model
            if (epoch + 1) % self.train_conf.snapshot_epoch == 0:
                logger.info("Saving Snapshot @ epoch {:04d}".format(epoch + 1))
                snapshot(model.module if self.use_gpu else model,
                         optimizer,
                         self.config,
                         epoch + 1,
                         scheduler=lr_scheduler)

        pickle.dump(
            results,
            open(os.path.join(self.config.save_dir, 'train_stats.p'), 'wb'))
        self.writer.close()

        return 1