Example #1
def train_graph_completion(args, dataset_test, rnn, output):
    fname = (args.model_save_path + args.fname + "lstm_" +
             str(args.load_epoch) + ".dat")
    rnn.load_state_dict(torch.load(fname))
    fname = (args.model_save_path + args.fname + "output_" +
             str(args.load_epoch) + ".dat")
    output.load_state_dict(torch.load(fname))

    epoch = args.load_epoch
    print("model loaded!, epoch: {}".format(args.load_epoch))

    for sample_time in range(1, 4):
        if "GraphRNN_MLP" in args.note:
            G_pred = test_mlp_partial_simple_epoch(epoch,
                                                   args,
                                                   rnn,
                                                   output,
                                                   dataset_test,
                                                   sample_time=sample_time)
        if "GraphRNN_VAE" in args.note:
            G_pred = test_vae_partial_epoch(epoch,
                                            args,
                                            rnn,
                                            output,
                                            dataset_test,
                                            sample_time=sample_time)
        # save graphs
        fname = (args.graph_save_path + args.fname_pred + str(epoch) + "_" +
                 str(sample_time) + "graph_completion.dat")
        save_graph_list(G_pred, fname)
    print("graph completion done, graphs saved")
Example #2
def just_generate(gg_model, dataset_train, args, gen_iter):
    if args.estimate_num_nodes:
        print('estimation of num_nodes_prob started')
        gg_model.num_nodes_prob = np.zeros(args.max_num_node + 1)
        for epoch in range(10):
            print(epoch, ' ', end='')
            sys.stdout.flush()
            for data in dataset_train:
                adj = data['adj'].to(args.device)
                for a in adj:
                    idx = a.sum(dim=0).bool().sum().item()
                    gg_model.num_nodes_prob[idx] += 1
        gg_model.num_nodes_prob = gg_model.num_nodes_prob / gg_model.num_nodes_prob.sum()
        print('estimation of num_nodes_prob finished')

    load_pretrained_model_weights(gg_model, gen_iter, args)


    for sample_time in range(1, 2):  # originally range(1, 4)
        print('     sample_time:', sample_time)
        G_pred = []
        while len(G_pred) < args.test_total_size:
            print('        len(G_pred):', len(G_pred))
            G_pred_step = generate_graph(gg_model, args)
            G_pred.extend(G_pred_step)
        # save graphs
        fname = args.graph_save_path + args.fname_pred + str(gen_iter) + '_' + str(sample_time) + '.dat'
        utils.save_graph_list(G_pred, fname)
    print('test done, graphs saved')
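The num_nodes_prob estimate above builds a histogram of graph sizes by counting, for each padded adjacency matrix, how many columns contain at least one nonzero entry. A self-contained demonstration of that counting step on a made-up toy batch:

import numpy as np
import torch

max_num_node = 4
num_nodes_prob = np.zeros(max_num_node + 1)

# one padded batch of two adjacency matrices: a triangle and a single edge
adj = torch.tensor([
    [[0, 1, 1, 0], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
    [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
], dtype=torch.float)

for a in adj:
    # a column with any nonzero entry belongs to a real (non-padding) node
    idx = a.sum(dim=0).bool().sum().item()
    num_nodes_prob[idx] += 1

num_nodes_prob = num_nodes_prob / num_nodes_prob.sum()
print(num_nodes_prob)  # [0. 0. 0.5 0.5 0.]: half the graphs have 2 nodes, half have 3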
Example #3
                    for k in range(10):
                        graphs.append(caveman_special(i, j, p_edge=0.3))
            utils.export_graphs_to_txt(graphs, output_prefix)
        elif prog_args.graph_type == "citeseer":
            graphs = utils.citeseer_ego()
            utils.export_graphs_to_txt(graphs, output_prefix)
        else:
            # load from directory
            input_path = dir_prefix + args.graph_save_path + args.fname_test + "0.dat"
            g_list = utils.load_graph_list(input_path)
            utils.export_graphs_to_txt(g_list, output_prefix)
    elif not prog_args.kron_dir == "":
        kron_g_list = process_kron(prog_args.kron_dir)
        fname = os.path.join(prog_args.kron_dir, prog_args.graph_type + ".dat")
        print([g.number_of_nodes() for g in kron_g_list])
        utils.save_graph_list(kron_g_list, fname)
    elif not prog_args.test_file == "":
        # evaluate a single .dat file containing a list of test graphs (networkx format)
        graphs = utils.load_graph_list(prog_args.test_file)
        eval_single_list(graphs,
                         dir_input=dir_prefix + "graphs/",
                         dataset_name="grid")
    ## if you don't try Kronecker, only the following part is needed
    else:
        if not os.path.isdir(dir_prefix + "eval_results"):
            os.makedirs(dir_prefix + "eval_results")
        evaluation(
            args_evaluate,
            dir_input=dir_prefix + "graphs/",
            dir_output=dir_prefix + "eval_results/",
            model_name_all=args_evaluate.model_name_all,
            dataset_name_all=args_evaluate.dataset_name_all,
            args=args,
            overwrite=True,
        )
Example #4
    args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
    max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))])
    min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))])

    # args.max_num_node = 2000
    # show graphs statistics
    print(
        "total graph num: {}, training set: {}".format(len(graphs), len(graphs_train))
    )
    print("max number node: {}".format(args.max_num_node))
    print("max/min number edge: {}; {}".format(max_num_edge, min_num_edge))
    print("max previous node: {}".format(args.max_prev_node))

    # save ground truth graphs
    ## To get the train and test sets, slice manually after loading
    save_graph_list(graphs, os.path.join(args.graph_save_path, args.fname_train + "0.dat"))
    save_graph_list(graphs, os.path.join(args.graph_save_path, args.fname_test + "0.dat"))
    print(
        "train and test graphs saved at: ",
        os.path.join(args.graph_save_path, args.fname_test + "0.dat"),
    )

    ### keep commented for normal training; only needed for graph completion
    # remove_edges(graphs_lists)

    ### dataset initialization
    if "nobfs" in args.note:
        print("nobfs")
        dataset = Graph_sequence_sampler_pytorch_nobfs(
            graphs_train, max_num_node=args.max_num_node
        )
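The "nobfs" variant above skips the BFS reordering that GraphRNN normally applies to each training graph before serializing it into sequences. For reference, a minimal sketch of that reordering on a connected graph (my own illustration, not the repo's helper):

import networkx as nx

def bfs_ordered_adj(G, start=None):
    # relabel nodes by BFS discovery order, then rebuild the adjacency matrix;
    # the BFS ordering bounds how far back in the sequence an edge can reach
    start = next(iter(G.nodes)) if start is None else start
    order = [start] + [v for _, v in nx.bfs_edges(G, start)]  # connected graphs only
    return nx.to_numpy_array(G, nodelist=order)

print(bfs_ordered_adj(nx.path_graph(4)))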
Example #5
def train(args, dataset_train, rnn, output):
    # check if load existing model
    if args.load:
        fname = (args.model_save_path + args.fname + "lstm_" +
                 str(args.load_epoch) + ".dat")
        rnn.load_state_dict(torch.load(fname))
        fname = (args.model_save_path + args.fname + "output_" +
                 str(args.load_epoch) + ".dat")
        output.load_state_dict(torch.load(fname))

        args.lr = 0.00001
        epoch = args.load_epoch
        print("model loaded!, lr: {}".format(args.lr))
    else:
        epoch = 1

    # initialize optimizer
    optimizer_rnn = optim.Adam(list(rnn.parameters()), lr=args.lr)
    optimizer_output = optim.Adam(list(output.parameters()), lr=args.lr)

    scheduler_rnn = MultiStepLR(optimizer_rnn,
                                milestones=args.milestones,
                                gamma=args.lr_rate)
    scheduler_output = MultiStepLR(optimizer_output,
                                   milestones=args.milestones,
                                   gamma=args.lr_rate)

    # start main loop
    time_all = np.zeros(args.epochs)
    while epoch <= args.epochs:
        time_start = tm.time()
        # train
        if "GraphRNN_VAE" in args.note:
            train_vae_epoch(
                epoch,
                args,
                rnn,
                output,
                dataset_train,
                optimizer_rnn,
                optimizer_output,
                scheduler_rnn,
                scheduler_output,
            )
        elif "GraphRNN_MLP" in args.note:
            train_mlp_epoch(
                epoch,
                args,
                rnn,
                output,
                dataset_train,
                optimizer_rnn,
                optimizer_output,
                scheduler_rnn,
                scheduler_output,
            )
        elif "GraphRNN_RNN" in args.note:
            train_rnn_epoch(
                epoch,
                args,
                rnn,
                output,
                dataset_train,
                optimizer_rnn,
                optimizer_output,
                scheduler_rnn,
                scheduler_output,
            )
        time_end = tm.time()
        time_all[epoch - 1] = time_end - time_start

        # test
        if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
            for sample_time in range(1, 4):
                G_pred = []
                while len(G_pred) < args.test_total_size:
                    if "GraphRNN_VAE" in args.note:
                        G_pred_step = test_vae_epoch(
                            epoch,
                            args,
                            rnn,
                            output,
                            test_batch_size=args.test_batch_size,
                            sample_time=sample_time,
                        )
                    elif "GraphRNN_MLP" in args.note:
                        G_pred_step = test_mlp_epoch(
                            epoch,
                            args,
                            rnn,
                            output,
                            test_batch_size=args.test_batch_size,
                            sample_time=sample_time,
                        )
                    elif "GraphRNN_RNN" in args.note:
                        G_pred_step = test_rnn_epoch(
                            epoch,
                            args,
                            rnn,
                            output,
                            test_batch_size=args.test_batch_size,
                        )
                    G_pred.extend(G_pred_step)
                # save graphs
                fname = (args.graph_save_path + args.fname_pred + str(epoch) +
                         "_" + str(sample_time) + ".dat")
                save_graph_list(G_pred, fname)
                draw_graph(random.choice(G_pred), prefix=f"collagen-{epoch}")
                if "GraphRNN_RNN" in args.note:
                    break
            print("test done, graphs saved")

        # save model checkpoint
        if args.save:
            if epoch % args.epochs_save == 0:
                fname = (args.model_save_path + args.fname + "lstm_" +
                         str(epoch) + ".dat")
                torch.save(rnn.state_dict(), fname)
                fname = (args.model_save_path + args.fname + "output_" +
                         str(epoch) + ".dat")
                torch.save(output.state_dict(), fname)
        epoch += 1
    np.save(args.timing_save_path + args.fname, time_all)
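The two MultiStepLR schedulers above multiply the learning rate by gamma each time a milestone epoch is passed. A self-contained demonstration of that schedule with a dummy parameter and made-up milestones:

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = Adam([param], lr=0.003)
scheduler = MultiStepLR(optimizer, milestones=[3, 6], gamma=0.3)

for epoch in range(1, 9):
    optimizer.step()   # a real training epoch would go here
    scheduler.step()   # lr drops to 0.0009 after epoch 3 and 0.00027 after epoch 6
    print(epoch, optimizer.param_groups[0]["lr"])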
Example #6
                    for k in range(10):
                        graphs.append(caveman_special(i,j, p_edge=0.3))
            utils.export_graphs_to_txt(graphs, output_prefix)
        elif prog_args.graph_type == 'citeseer':
            graphs = utils.citeseer_ego()
            utils.export_graphs_to_txt(graphs, output_prefix)
        else:
            # load from directory
            input_path = dir_prefix + real_graph_filename
            g_list = utils.load_graph_list(input_path)
            utils.export_graphs_to_txt(g_list, output_prefix)
    elif not prog_args.kron_dir == '':
        kron_g_list = process_kron(prog_args.kron_dir)
        fname = os.path.join(prog_args.kron_dir, prog_args.graph_type + '.dat')
        print([g.number_of_nodes() for g in kron_g_list])
        utils.save_graph_list(kron_g_list, fname)
    elif not prog_args.test_file == '':
        # evaluate single .dat file containing list of test graphs (networkx format)
        graphs = utils.load_graph_list(prog_args.test_file)
        eval_single_list(graphs, dir_input=dir_prefix+'graphs/', dataset_name='grid')
    ## if you don't try kronecker, only the following part is needed
    else:
        if not os.path.isdir(dir_prefix+'eval_results'):
            os.makedirs(dir_prefix+'eval_results')
        evaluation(args_evaluate,
                   dir_input=dir_prefix + 'graphs/',
                   dir_output=dir_prefix + 'eval_results/',
                   model_name_all=args_evaluate.model_name_all,
                   dataset_name_all=args_evaluate.dataset_name_all,
                   args=args,
                   overwrite=True)
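save_graph_list and load_graph_list are called throughout these examples but never defined here. A plausible minimal implementation, assuming the .dat files are simply pickled lists of networkx graphs (an assumption about the format, not the repo's actual code):

import pickle

def save_graph_list(G_list, fname):
    # assumed format: one pickled list of networkx graphs per .dat file
    with open(fname, "wb") as f:
        pickle.dump(G_list, f)

def load_graph_list(fname):
    with open(fname, "rb") as f:
        return pickle.load(f)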




Example #7
        print('Creating dataset with ', num_communities, ' communities')
        c_sizes = np.random.choice([12, 13, 14, 15, 16, 17], num_communities)
        for k in range(3000):
            graphs.append(utils.n_community(c_sizes, p_inter=0.01))
        X_dataset = [nx.to_numpy_array(g) for g in graphs]  # nx.to_numpy_matrix was removed in NetworkX 3.x

    print('Number of graphs: ', len(X_dataset))
    K = prog_args.K  # number of clusters
    gen_graphs = []
    for i in range(len(X_dataset)):
        if i % 5 == 0:  # fit a blockmodel to every 5th graph only
            print(i)
            X_data = X_dataset[i]
            N = X_data.shape[0]  # number of vertices

            Zp, B = mmsb(N, K, X_data)
            #print("Block: ", B)
            Z_pred = Zp.argmax(axis=1)
            print("Result (label flip can happen):")
            #print("prob: ", Zp)
            print("Predicted")
            print(Z_pred)
            #print(Z_true)
            #print("Adjusted Rand Index =", adjusted_rand_score(Z_pred, Z_true))
            for j in range(prog_args.samples):
                gen_graphs.append(graph_gen_from_blockmodel(B, Zp))

    save_path = '/lfs/local/0/rexy/graph-generation/eval_results/mmsb/'
    utils.save_graph_list(gen_graphs,
                          os.path.join(save_path, prog_args.dataset + '.dat'))
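graph_gen_from_blockmodel is not shown above. One standard way to sample a graph from the estimated memberships Zp (N x K) and block matrix B (K x K) is to form pairwise edge probabilities Zp @ B @ Zp.T and draw independent Bernoulli edges; a sketch of that idea, not necessarily the repo's implementation:

import numpy as np
import networkx as nx

def sample_from_blockmodel(B, Zp, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    P = Zp @ B @ Zp.T                          # per-pair edge probabilities
    A = (rng.random(P.shape) < P).astype(int)  # Bernoulli draws
    A = np.triu(A, k=1)                        # drop self-loops and the lower triangle
    A = A + A.T                                # symmetrize: simple undirected graph
    return nx.from_numpy_array(A)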
Example #8
def train(gg_model, dataset_train, dataset_validation, dataset_test, optimizer, args):

    ## initialize optimizer
    ## optimizer = torch.optim.Adam(list(gcade_model.parameters()), lr=args.lr)
    ## scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.lr_rate)

    if args.estimate_num_nodes or args.weight_positions:
        print('estimation of num_nodes_prob started')
        num_nodes_prob = np.zeros(args.max_num_node + 1)
        for epoch in range(10):
            print(epoch, ' ', end='')
            sys.stdout.flush()
            for data in dataset_train:
                adj = data['adj'].to(args.device)
                for a in adj:
                    idx = a.sum(dim=0).bool().sum().item()
                    num_nodes_prob[idx] += 1
        num_nodes_prob = num_nodes_prob / num_nodes_prob.sum()
        print('estimation of num_nodes_prob finished')
        if args.estimate_num_nodes:
            gg_model.num_nodes_prob = num_nodes_prob
        if args.weight_positions:
            tmp = np.cumsum(num_nodes_prob, axis=0)
            tmp = 1 - tmp[:-1]
            tmp = np.concatenate([np.array([1.]), tmp])
            tmp[tmp <= 0] = np.min(tmp[tmp > 0])
            position_weights = 1 / tmp
            gg_model.positions_weights = torch.tensor(position_weights).to(args.device).view(1, -1)



    # start main loop
    time_all = np.zeros(args.epochs)
    loss_buffer = []
    if args.epoch_train_start > 0:
        load_pretrained_model_weights(gg_model, args.epoch_train_start - 1, args)
        optimizer.set_n_steps(args.epoch_train_start * args.batch_ratio)
    for epoch in range(args.epoch_train_start, args.epochs):
        time_start = time.time()
        running_loss = 0.0
        trsz = 0
        gg_model.train()
        for i, data in enumerate(dataset_train, 0):
            if args.use_MADE:
                gg_model.trg_word_MADE.update_masks()
            # print(' #', i)
            print('.', end='')
            sys.stdout.flush()
            src_seq = data['src_seq'].to(args.device)
            trg_seq = data['src_seq'].to(args.device)  # decoder input mirrors the source sequence (same below)

            '''
            for j in range(src_seq.size(1)):
                ind = src_seq[:,j,0] == args.zero_input
                tmp = args.dontcare_input * torch.ones(ind.sum().item(), src_seq.size(-1)).to(args.device)

                # tmp[:, :] = args.zero_input
                tmp[:, :j] = args.zero_input
                tmp[:, j] = args.one_input

                # tmp[:, :] = args.one_input
                # tmp[:, 0] = args.zero_input

                src_seq[ind, j,  :] = tmp.clone()
                trg_seq[ind, j,  :] = tmp.clone()
            '''

            gold = data['trg_seq'].contiguous().to(args.device)
            adj = data['adj'].to(args.device)

            optimizer.zero_grad()
            pred, dec_output = gg_model(src_seq, trg_seq, gold, adj)
            if (not args.weight_termination_bit) or (epoch > args.termination_bit_weight_last_epoch):
                loss, *_ = cal_performance( pred, dec_output, gold, trg_pad_idx=0, args=args, model=gg_model, smoothing=False)
            else:
                tmp = (args.termination_bit_weight_last_epoch - epoch) / args.termination_bit_weight_last_epoch
                termination_bit_weight = (tmp ** 2) * (args.termination_bit_weight - 1) + 1

                print('                   tbw: ', termination_bit_weight)
                loss, *_ = cal_performance( pred, dec_output, gold, trg_pad_idx=0, args=args, model=gg_model,
                                            termination_bit_weight=termination_bit_weight, smoothing=False)

            # print('  ', loss.item() / input_nodes.size(0))
            loss.backward()
            optimizer.step_and_update_lr()

            running_loss += loss.item()
            trsz += src_seq.size(0)

        val_running_loss = 0.0
        vlsz = 0
        gg_model.eval()
        for i, data in enumerate(dataset_validation):
            if args.use_MADE:
                gg_model.trg_word_MADE.update_masks()
            src_seq = data['src_seq'].to(args.device)
            trg_seq = data['src_seq'].to(args.device) 
            gold = data['trg_seq'].contiguous().to(args.device)
            adj = data['adj'].to(args.device)

            pred, dec_output = gg_model(src_seq, trg_seq, gold, adj)
            loss, *_ = cal_performance( pred, dec_output, gold, trg_pad_idx=0, args=args, model=gg_model, smoothing=False)

            val_running_loss += loss.item()
            vlsz += src_seq.size(0)

        test_running_loss = 0.0
        testsz = 0
        gg_model.eval()
        for i, data in enumerate(dataset_test):
            if args.use_MADE:
                gg_model.trg_word_MADE.update_masks()
            src_seq = data['src_seq'].to(args.device)
            trg_seq = data['src_seq'].to(args.device)
            gold = data['trg_seq'].contiguous().to(args.device)
            adj = data['adj'].to(args.device)

            pred, dec_output = gg_model(src_seq, trg_seq, gold, adj)
            loss, *_ = cal_performance(pred, dec_output, gold, trg_pad_idx=0, args=args, model=gg_model, smoothing=False)

            test_running_loss += loss.item()
            testsz += src_seq.size(0)

        if epoch % args.epochs_save == 0:
            fname = args.model_save_path + args.fname + '_' + args.graph_type + '_'  + str(epoch) + '.dat'
            torch.save(gg_model.state_dict(), fname)

        loss_buffer.append(running_loss / trsz)
        if len(loss_buffer) > 5:
            loss_buffer = loss_buffer[1:]
        print('[epoch %d]     loss: %.3f     val: %.3f     test: %.3f              lr: %f     avg_tr_loss: %f' %
              (epoch + 1, running_loss / trsz, val_running_loss / vlsz, test_running_loss / testsz,
               optimizer._optimizer.param_groups[0]['lr'], np.mean(loss_buffer)))
        # print(list(gg_model.encoder.layer_stack[0].slf_attn.gr_att_linear_list[0].parameters()))
        sys.stdout.flush()
        time_end = time.time()
        time_all[epoch] = time_end - time_start  # epoch starts at 0 here, so no -1 offset
        # test
        if epoch % args.epochs_test == 0 and epoch >= args.epochs_test_start:
            for sample_time in range(1, 2):  # originally range(1, 4)
                print('     sample_time:', sample_time)
                G_pred = []
                while len(G_pred) < args.test_total_size:
                    print('        len(G_pred):', len(G_pred))
                    G_pred_step = generate_graph(gg_model, args)
                    G_pred.extend(G_pred_step)
                # save graphs
                fname = args.graph_save_path + args.fname_pred + str(epoch) + '_' + str(sample_time) + '.dat'
                utils.save_graph_list(G_pred, fname)
            print('test done, graphs saved')
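The optimizer passed into this train function is not a bare torch optimizer: it exposes step_and_update_lr and set_n_steps and wraps an inner ._optimizer, which matches the Transformer-style warmup scheduler pattern. A minimal sketch of such a wrapper, assuming the usual inverse-square-root schedule; the repo's actual class and constructor arguments may differ:

import torch

class ScheduledOptim:
    """Warmup-then-decay learning-rate wrapper (a sketch; names assumed)."""

    def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.lr_mul = lr_mul
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0

    def set_n_steps(self, n_steps):
        # lets a resumed run continue mid-schedule, as in the epoch_train_start branch
        self.n_steps = n_steps

    def zero_grad(self):
        self._optimizer.zero_grad()

    def step_and_update_lr(self):
        self.n_steps += 1
        scale = min(self.n_steps ** -0.5,
                    self.n_steps * self.n_warmup_steps ** -1.5)
        lr = self.lr_mul * (self.d_model ** -0.5) * scale
        for group in self._optimizer.param_groups:
            group['lr'] = lr
        self._optimizer.step()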
Example #9
    # graphs_train = graphs[0:int(0.8 * graphs_len)]
    # graphs_validate = graphs[int(0.2 * graphs_len):int(0.4 * graphs_len)]

else:
    random.shuffle(graphs)

graphs_len = len(graphs)
graphs_test = graphs[int((1 - args.test_portion) * graphs_len):]
graphs_train = graphs[0:int(args.training_portion * graphs_len)]
graphs_validate = graphs[int((1 - args.test_portion - args.validation_portion) * graphs_len):
                         int((1 - args.test_portion) * graphs_len)]

if not args.use_pre_saved_graphs:
    # save ground truth graphs
    ## To get the train and test sets, slice manually after loading
    save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
    save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
    print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat')

# average node count of the validation graphs (despite the "_len" name)
graph_validate_len = 0
for graph in graphs_validate:
    graph_validate_len += graph.number_of_nodes()
graph_validate_len /= len(graphs_validate)
print('graph_validate_len', graph_validate_len)

graph_test_len = 0
for graph in graphs_test:
    graph_test_len += graph.number_of_nodes()
graph_test_len /= len(graphs_test)
print('graph_test_len', graph_test_len)
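The slicing above takes the test set from the tail of the shuffled list, the training set from the head, and the validation set from the stretch just before the test block; note that train and validation can overlap whenever training_portion exceeds 1 - test_portion - validation_portion. The same logic packaged as a function, purely for illustration:

def split_graphs(graphs, training_portion, validation_portion, test_portion):
    n = len(graphs)
    graphs_test = graphs[int((1 - test_portion) * n):]
    graphs_train = graphs[:int(training_portion * n)]
    graphs_validate = graphs[int((1 - test_portion - validation_portion) * n):
                             int((1 - test_portion) * n)]
    return graphs_train, graphs_validate, graphs_test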