示例#1
0
if __name__ == '__main__':
    # Parse command-line flags. parse_known_args() returns unrecognised
    # arguments in `unparsed` instead of raising an error.
    nmt_parser = argparse.ArgumentParser()
    add_arguments(nmt_parser)
    FLAGS, unparsed = nmt_parser.parse_known_args()
    hparams = create_hparams(FLAGS)

    # loading the data from a file (hparams.graph_file; hparams.nodes is
    # also passed to load_data)
    adj, features, edges = load_data(hparams.graph_file, hparams.nodes)
    # Model dimensions are read from the first example only; presumably all
    # examples share the same node count and feature width — TODO confirm.
    num_nodes = adj[0].shape[0]
    num_features = features[0].shape[1]
    #print("Debug", num_nodes, adj[0][0])
    # Training
    # NOTE(review): `placeholders` is not defined in this block; presumably a
    # module-level global defined elsewhere in the file — verify.
    model = VAEG(hparams, placeholders, num_nodes, num_features, edges)
    # Uncomment to load saved state from hparams.out_dir (presumably a
    # checkpoint) before training.
    # model.restore(hparams.out_dir)
    model.initialize()
    model.train(placeholders, hparams, adj, features)

    # Test code (disabled: kept inside a bare string literal so it never runs)
    '''
    model2 = VAEG(hparams, placeholders, 30, 1)
    model2.restore(hparams.out_dir)
    hparams.sample = True
    i = 0
    G_good = load_embeddings(hparams.z_dir+'train0.txt')
    G_bad = load_embeddings(hparams.z_dir+'test_11.txt')
    
    #model2.sample_graph_slerp(hparams, placeholders, 5, G_good, G_bad, num=29)
    
    while i < 10:
        G_bad = model2.sample_graph_slerp(hparams, placeholders, i, G_good, G_bad, num=29)
示例#2
0
      sample=flags.sample,
      neg_sample_size=flags.neg_sample_size,
      node_sample=flags.node_sample,
      bfs_sample=flags.bfs_sample
      )

if __name__ == '__main__':
    # Parse command-line flags. parse_known_args() returns unrecognised
    # arguments in `unparsed` instead of raising an error.
    nmt_parser = argparse.ArgumentParser()
    add_arguments(nmt_parser)
    FLAGS, unparsed = nmt_parser.parse_known_args()
    hparams = create_hparams(FLAGS)

    # Load the dataset from hparams.graph_file; load_data_new yields seven
    # values that are unpacked below and later passed to model.train().
    (adj, weight, weight_bin, features, edges, neg_edges,
     features1) = load_data_new(
        hparams.graph_file,
        hparams.nodes,
        hparams.node_sample,
        hparams.bfs_sample,
        hparams.bin_dim,
    )

    # Model dimensions are read from the first example only; presumably all
    # examples share the same node count and feature width — TODO confirm.
    num_nodes = adj[0].shape[0]
    num_features = features[0].shape[1]

    # Print diagnostics about the loaded dataset.
    lenedges = [len(edge[0]) for edge in edges]
    lenweight_bin = [len(weight_b[0]) for weight_b in weight_bin]
    print("Len edges", lenedges, lenweight_bin)
    print("Num features", num_features)
    print("Num examples", len(adj))

    # Largest len(edge) over all examples; it bounds the argument passed to
    # log_fact. max() raises ValueError if `edges` is empty.
    # NOTE(review): this measures len(edge) while `lenedges` above measures
    # len(edge[0]) — confirm which count log_fact is meant to cover.
    e = max(len(edge) for edge in edges)
    log_fact_k = log_fact(e)

    # Training: restore saved state from hparams.out_dir, then train.
    # NOTE(review): `placeholders` is not defined in this block; presumably a
    # module-level global defined elsewhere in the file — verify.
    model = VAEG(hparams, placeholders, num_nodes, num_features, log_fact_k,
                 len(adj))
    model.restore(hparams.out_dir)
    model.train(placeholders, hparams, adj, weight, weight_bin, features,
                edges, neg_edges, features1)