Example #1
def create_graph_facebook(main_folder=r'C:\nam\work\facebook'):
    '''
    Build the Facebook graph from the combined edge list and write the
    GraphSAGE-style input files (graph, id map, walks, and class map).
    :return: None
    '''
    # G = nx.Graph()
    edge_file = os.path.join(main_folder, 'facebook_combined.txt')
    G = nx.read_edgelist(edge_file, nodetype=int, create_using=nx.Graph())
    for edge in G.edges():
        G[edge[0]][edge[1]]['weight'] = 1
        G[edge[0]][edge[1]]['train_removed'] = False
        G[edge[0]][edge[1]]['test_removed'] = False
    for node in G.nodes():
        G.node[node]['val'] = False
        G.node[node]['test'] = False
        # # optionally attach a feature to each node
        # G.node[node]['feature'] = (node,)
    # dump the graph to face-G.json
    with open(os.path.join(main_folder, 'face-G.json'), 'w') as outfile1:
        outfile1.write(json.dumps(json_graph.node_link_data(G)))

    #create id_map.json
    nodes = list(G.nodes())
    id_map = {}
    for node in nodes:
        string_id = str(node)
        id_map[string_id] = node
    with open(os.path.join(main_folder, 'face-id_map.json'), 'w') as outfile1:
        outfile1.write(json.dumps(id_map))

    # create the walk file; this mainly calls run_random_walks from utils.py
    nodes = [
        n for n in G.nodes() if not G.node[n]["val"] and not G.node[n]["test"]
    ]
    G = G.subgraph(nodes)
    # this random-walk step could just as well use node2vec's random walk.
    pairs = run_random_walks(G, nodes)
    with open(os.path.join(main_folder, 'face-walks1.json'), "w") as fp:
        fp.write("\n".join([str(p[0]) + "\t" + str(p[1]) for p in pairs]))

    # create class_map file

    nodes = list(G.nodes())
    class_map = {}
    for node in nodes:
        string_id = str(node)
        class_map[string_id] = [
            1,
        ]
    with open(os.path.join(main_folder, 'face-class_map.json'), 'w') as outfile1:
        outfile1.write(json.dumps(class_map))
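
Note: the four files written above (face-G.json, face-id_map.json, face-walks1.json, face-class_map.json) follow the GraphSAGE input convention. As a quick sanity check they can be reloaded directly; the sketch below is illustrative only and assumes the folder and 'face' prefix used in this example.

import json
import os

from networkx.readwrite import json_graph


def load_graphsage_inputs(folder, prefix='face'):
    # reload the node-link graph written by json_graph.node_link_data
    with open(os.path.join(folder, prefix + '-G.json')) as f:
        G = json_graph.node_link_graph(json.load(f))
    # id_map and class_map are plain JSON dicts keyed by the string node id
    with open(os.path.join(folder, prefix + '-id_map.json')) as f:
        id_map = json.load(f)
    with open(os.path.join(folder, prefix + '-class_map.json')) as f:
        class_map = json.load(f)
    return G, id_map, class_map

# G, id_map, class_map = load_graphsage_inputs(r'C:\nam\work\facebook')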
Example #2
def dynamic_test(init_file, dynamic_file, flag_file):
    [G, feats, id_map, walks, class_map, node_flag,
     flag_no] = init_G(init_file, flag_file, FLAGS.feature_size)
    print(id_map)
    print(FLAGS.isTrain)
    raw_input()
    train([G, np.array(feats), id_map, walks, class_map])
    change_G_status(G)
    print("init succeed! num of edges in G is " + G.number_of_edges())
    raw_input()
    with open(dynamic_file, 'r') as infile:
        while True:
            line = infile.readline()
            if not line:
                break
            if line == '\r\n':
                continue
            [node_id, m] = line.strip().split()
            edges_added = []
            for i in range(int(m)):
                line = infile.readline()
                items = line.strip().split()
                edges_added.append([int(it) for it in items])
            update_G(G, feats, id_map, class_map, edges_added, node_flag,
                     flag_no, FLAGS.feature_size)
            print("init succeed! num of edges in G is " + G.number_of_edges())
            raw_input()
            # construct walks
            nodes = [
                n for n in G.nodes()
                if not G.node[n]['val'] and not G.node[n]['test']
            ]
            G_part = G.subgraph(nodes)
            walks = run_random_walks(G_part, nodes)

            # train model
            train([G, feats, id_map, walks, class_map])
            change_G_status(G)
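
Note: Examples #1, #2, and #5 all call run_random_walks from graphsage/utils.py to produce the co-occurrence pairs used as random-walk context. A minimal sketch of what such a routine does is shown below; the walk length and walk count are assumptions, so treat it as illustrative rather than the exact upstream implementation.

import random


def run_random_walks_sketch(G, nodes, num_walks=50, walk_len=5):
    # For every start node, take num_walks short walks and record
    # (start, visited) co-occurrence pairs, skipping isolated nodes.
    pairs = []
    for node in nodes:
        if G.degree(node) == 0:
            continue
        for _ in range(num_walks):
            curr = node
            for _ in range(walk_len):
                next_node = random.choice(list(G.neighbors(curr)))
                if curr != node:  # self co-occurrences carry no signal
                    pairs.append((node, curr))
                curr = next_node
    return pairs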
Example #3
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # default to saving all variables : added by wy
    saver = tf.train.Saver()

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Define the adjacency-assign ops up front so the embedding-saving branch
    # below can switch to the validation adjacency even when FLAGS.isTrain is False.
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    # Train model
    if FLAGS.isTrain:
        print("== Training Model ==")
        train_shadow_mrr = None
        shadow_mrr = None

        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        for epoch in range(FLAGS.epochs):
            minibatch.shuffle()

            iter = 0
            print('Epoch: %04d' % (epoch + 1))
            epoch_val_costs.append(0)
            while not minibatch.end():
                # Construct feed dictionary
                feed_dict = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})

                t = time.time()
                # Training step
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1
                ],
                                feed_dict=feed_dict)
                train_cost = outs[2]
                train_mrr = outs[5]
                if train_shadow_mrr is None:
                    train_shadow_mrr = train_mrr  #
                else:
                    train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr -
                                                      train_mrr)

                if iter % FLAGS.validate_iter == 0:
                    # Validation
                    sess.run(val_adj_info.op)
                    val_cost, ranks, val_mrr, duration = evaluate(
                        sess, model, minibatch, size=FLAGS.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost
                if shadow_mrr is None:
                    shadow_mrr = val_mrr
                else:
                    shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

                if total_steps % FLAGS.print_every == 0:
                    summary_writer.add_summary(outs[0], total_steps)

                # Print results
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % FLAGS.print_every == 0:
                    print(
                        "Iter:",
                        '%04d' % iter,
                        "train_loss=",
                        "{:.5f}".format(train_cost),
                        "train_mrr=",
                        "{:.5f}".format(train_mrr),
                        "train_mrr_ema=",
                        "{:.5f}".format(
                            train_shadow_mrr),  # exponential moving average
                        "val_loss=",
                        "{:.5f}".format(val_cost),
                        "val_mrr=",
                        "{:.5f}".format(val_mrr),
                        "val_mrr_ema=",
                        "{:.5f}".format(
                            shadow_mrr),  # exponential moving average
                        "time=",
                        "{:.5f}".format(avg_time))

                iter += 1
                total_steps += 1

                if total_steps > FLAGS.max_total_steps:
                    break

            if total_steps > FLAGS.max_total_steps:
                break
        print("Optimization Finished!")
        # save model
        saver.save(sess, log_dir() + 'model.ckpt')
        print("save model succeed!")
    else:
        print("==Testing Model==")
        ckpt = tf.train.get_checkpoint_state(log_dir())
        print(log_dir())
        print(ckpt)
        print(ckpt.model_checkpoint_path)
        raw_input()
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass

    raw_input()
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Example #4
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.compat.v1.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.compat.v1.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize WandB experiment
    wandb.init(project='GraphSAGE_trial',
               entity='cvl-liu-01',
               save_code=True,
               tags=['unsupervised'])
    wandb.config.update(flags.FLAGS)

    # Initialize session
    sess = tf.compat.v1.Session(config=config)
    merged = tf.compat.v1.summary.merge_all()
    summary_writer = tf.compat.v1.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.compat.v1.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Init saver
    saver = tf.compat.v1.train.Saver(max_to_keep=8,
                                     keep_checkpoint_every_n_hours=1)

    # Train model
    train_shadow_mrr = None
    val_shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
    val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  #
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            # Validation
            if iter % FLAGS.validate_iter == 0:
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if val_shadow_mrr is None:
                val_shadow_mrr = val_mrr
            else:
                val_shadow_mrr -= (1 - 0.99) * (val_shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "[%03d/%03d]" % (epoch + 1, FLAGS.epochs),
                    "Iter:",
                    '[%05d/%05d]' % (iter, minibatch.num_training_batches()),
                    "train_loss =",
                    "{:.5f}".format(train_cost),
                    "train_mrr =",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema =",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss =",
                    "{:.5f}".format(val_cost),
                    "val_mrr =",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema =",
                    "{:.5f}".format(
                        val_shadow_mrr),  # exponential moving average
                    "time =",
                    "{:.5f}".format(avg_time))

            # W&B Logging
            if FLAGS.wandb_log and iter % FLAGS.wandb_log_iter == 0:
                wandb.log({'train_loss': train_cost, 'epoch': epoch})
                wandb.log({'train_mrr': train_mrr, 'epoch': epoch})
                wandb.log({'train_mrr_ema': train_shadow_mrr, 'epoch': epoch})
                wandb.log({'val_loss': val_cost, 'epoch': epoch})
                wandb.log({'val_mrr': val_mrr, 'epoch': epoch})
                wandb.log({'val_mrr_ema': val_shadow_mrr, 'epoch': epoch})
                wandb.log({'time': avg_time, 'epoch': epoch})

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                print('Max total steps reached!')
                break

        # Save embeddings
        if FLAGS.save_embeddings and epoch % FLAGS.save_embeddings_epoch == 0:
            save_val_embeddings(sess, model, minibatch,
                                FLAGS.validate_batch_size, log_dir())

            # Also report classifier metric on the embedding
            all_tr_res, all_ts_res = osm_classif_eval.evaluate(
                FLAGS.train_prefix, log_dir(), n_iter=FLAGS.classif_n_iter)
            if FLAGS.wandb_log:
                wandb.log(all_tr_res)
                wandb.log(all_ts_res)

        # Save Model checkpoints
        if FLAGS.save_checkpoints and epoch % FLAGS.save_checkpoints_epoch == 0:
            # saver.save(sess, log_dir() + 'model', global_step=1000)
            print('Save model checkpoint:', wandb.run.dir, iter, total_steps,
                  epoch)
            saver.save(
                sess,
                os.path.join(wandb.run.dir,
                             "model-" + str(epoch + 1) + ".ckpt"))

        if total_steps > FLAGS.max_total_steps:
            print('Max total steps reached!')
            break

    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Example #5
def dynamic_test(init_file, dynamic_file, feats, id_map, flag_file, params,
                 metric, draw):
    [G, walks, class_map, node_flag, flag_no] = init_G(init_file, flag_file,
                                                       id_map, feats)
    train([G, feats, id_map, walks, class_map], params, G.number_of_nodes(),
          metric, draw)
    change_G_status(G)
    print("first test finished!, enter to continue")
    node_edges_lst = []
    with open(dynamic_file, 'r') as infile:
        while True:
            line = infile.readline()
            if not line:
                break
            items = line.strip().split()
            if len(items) != 2:
                continue
            edges_added = []
            for i in range(int(items[1])):
                line = infile.readline()
                items = line.strip().split()
                edges_added.append([int(it) for it in items])
            node_edges_lst.append(edges_added)

    edges_added = []
    none_line = 0
    start_time = datetime.datetime.now()
    for i in range(len(node_edges_lst)):
        # add edges
        none_line += 1
        edges_added += node_edges_lst[i]

        if none_line == FLAGS.test_batch_size:
            update_G(G, feats, id_map, class_map, edges_added, node_flag,
                     flag_no)
            # construct walks
            nodes = [
                n for n in G.nodes()
                if not G.node[n]['val'] and not G.node[n]['test']
            ]
            G_part = G.subgraph(nodes)
            walks = run_random_walks(G_part, nodes)
            # train model
            train([G, feats, id_map, walks, class_map], params,
                  G.number_of_nodes(), metric, draw)
            change_G_status(G)
            print("update status: " + str(G.number_of_nodes()) +
                  "nodes, enter to continue")

            end_time = datetime.datetime.now()
            dh.append_to_file(params['output_path'] + "_time",
                              str(end_time - start_time) + "\n")
            start_time = end_time
            none_line = 0
            edges_added = []

    if len(edges_added):
        update_G(G, feats, id_map, class_map, edges_added, node_flag, flag_no)
        # construct walks
        nodes = [
            n for n in G.nodes()
            if not G.node[n]['val'] and not G.node[n]['test']
        ]
        G_part = G.subgraph(nodes)
        walks = run_random_walks(G_part, nodes)
        # train model
        train([G, feats, id_map, walks, class_map], params,
              G.number_of_nodes(), metric, draw)
        change_G_status(G)
        end_time = datetime.datetime.now()
        dh.append_to_file(params['output_path'] + "_time",
                          str(end_time - start_time) + "\n")
        print("update status: " + str(G.number_of_nodes()) +
              "nodes, enter to continue")
Example #6
def inference(train_data, test_data=None):
    '''
    Embed new nodes by restoring a previously trained GraphSAGE model.
    :param train_data: [G, features, id_map, walks] as prepared elsewhere
    :param test_data: unused
    :return: None
    '''
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders, batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(placeholders, features.shape[0],
                              minibatch.deg,
                              # 2x because graphsage uses concat
                              nodevec_dim=2 * FLAGS.dim_1,
                              lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    # merged = tf.summary.merge_all()
    # summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    #save model
    saver = tf.train.Saver()
    # Init variables
    sess.run(tf.global_variables_initializer())

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    # for epoch in range(FLAGS.epochs):
    #     minibatch.shuffle()
    #
    #     iter = 0
    #     print('Epoch: %04d' % (epoch + 1))
    #     epoch_val_costs.append(0)
    #     while not minibatch.end():
    #         # Construct feed dictionary
    #         feed_dict = minibatch.next_minibatch_feed_dict()
    #         feed_dict.update({placeholders['dropout']: FLAGS.dropout})
    #
    #         t = time.time()
    #         # Training step
    #         outs = sess.run([merged, model.opt_op, model.loss, model.ranks, model.aff_all,
    #                          model.mrr, model.outputs1], feed_dict=feed_dict)
    #         train_cost = outs[2]
    #         train_mrr = outs[5]
    #         if train_shadow_mrr is None:
    #             train_shadow_mrr = train_mrr  #
    #         else:
    #             train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)
    #
    #         if iter % FLAGS.validate_iter == 0:
    #             # Validation
    #             sess.run(val_adj_info.op)
    #             val_cost, ranks, val_mrr, duration = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size)
    #             sess.run(train_adj_info.op)
    #             epoch_val_costs[-1] += val_cost
    #         if shadow_mrr is None:
    #             shadow_mrr = val_mrr
    #         else:
    #             shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)
    #
    #         if total_steps % FLAGS.print_every == 0:
    #             summary_writer.add_summary(outs[0], total_steps)
    #
    #         # Print results
    #         avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)
    #
    #         if total_steps % FLAGS.print_every == 0:
    #             print("Iter:", '%04d' % iter,
    #                   "train_loss=", "{:.5f}".format(train_cost),
    #                   "train_mrr=", "{:.5f}".format(train_mrr),
    #                   "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr),  # exponential moving average
    #                   "val_loss=", "{:.5f}".format(val_cost),
    #                   "val_mrr=", "{:.5f}".format(val_mrr),
    #                   "val_mrr_ema=", "{:.5f}".format(shadow_mrr),  # exponential moving average
    #                   "time=", "{:.5f}".format(avg_time))
    #
    #         iter += 1
    #         total_steps += 1
    #
    #         if total_steps > FLAGS.max_total_steps:
    #             break
    #
    #     if total_steps > FLAGS.max_total_steps:
    #         break

    # restore model
    file = r"C:\Users\Windows 10 TIMT\OneDrive\Nam\OneDrive - Five9 Vietnam Corporation\work\learn_five9\GraphSAGE\graphsage\unsup-facebook\graphsage_mean_small_0.000010\\final model"
    ckpt = tf.train.get_checkpoint_state(file)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

    # saver.restore(sess,os.path.join(log_dir(), "model.ckpt"))
    # saver.restore(sess,)
    print("Model restored.")
    if FLAGS.save_embeddings:
        # sess.run(val_adj_info.op)
        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter() if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes, tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes, tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [n for n in G.nodes_iter() if G.node[n]["val"] or G.node[n]["test"]]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(G,
                                                   id_map,
                                                   placeholders, batch_size=FLAGS.batch_size,
                                                   max_degree=FLAGS.max_degree,
                                                   num_neg_samples=FLAGS.neg_sample_size,
                                                   context_pairs=pairs,
                                                   n2v_retrain=True,
                                                   fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([model.opt_op, model.loss, model.ranks, model.aff_all,
                                     model.mrr, model.outputs1], feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps,
                              "train_loss=", "{:.5f}".format(outs[1]),
                              "train_mrr=", "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(), mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Example #7
def train(train_data, action, test_data=None, regress_fun=run_regression):
    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    # Initialize session
    # use a dedicated graph/scope here, otherwise it conflicts with the Controller
    with tf.Session(config=config, graph=tf.Graph()) as sess:
        with sess.as_default():
            with sess.graph.as_default():
                # Set random seed
                seed = 123
                np.random.seed(seed)
                tf.set_random_seed(seed)

                G = train_data[0]
                features = train_data[1]
                id_map = train_data[2]

                if features is not None:
                    # pad with dummy zero vector
                    features = np.vstack(
                        [features, np.zeros((features.shape[1], ))])

                context_pairs = train_data[3] if FLAGS.random_context else None
                placeholders = construct_placeholders()
                minibatch = EdgeMinibatchIterator(
                    G,
                    id_map,
                    placeholders,
                    batch_size=FLAGS.batch_size,
                    max_degree=FLAGS.max_degree,
                    num_neg_samples=FLAGS.neg_sample_size,
                    context_pairs=context_pairs)
                adj_info_ph = tf.placeholder(tf.int32,
                                             shape=minibatch.adj.shape)
                adj_info = tf.Variable(adj_info_ph,
                                       trainable=False,
                                       name="adj_info")

                sampler = UniformNeighborSampler(adj_info)  # neighbor sampling by randomly reshuffling neighbors
                state_nums = 2  # number of states defined by the Controller
                layers_num = len(action) // state_nums  # number of layers
                layer_infos = []
                # used to guide construction of the final GNN
                layer_infos.append(
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1))
                layer_infos.append(
                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1))
                # unsupervised GraphSAGE used for NAS
                model = NASUnsupervisedGraphsage(
                    placeholders,
                    features,
                    adj_info,
                    minibatch.deg,
                    layer_infos=layer_infos,
                    action=action,
                    state_nums=state_nums,
                    model_size=FLAGS.model_size,
                    identity_dim=FLAGS.identity_dim,
                    logging=True)

                merged = tf.summary.merge_all()
                summary_writer = tf.summary.FileWriter(log_dir(action),
                                                       sess.graph)

                # Init variables
                sess.run(tf.global_variables_initializer(),
                         feed_dict={adj_info_ph: minibatch.adj})

                # Train model

                train_shadow_mrr = None
                shadow_mrr = None

                total_steps = 0
                avg_time = 0.0
                epoch_val_costs = []

                train_adj_info = tf.assign(adj_info, minibatch.adj)
                val_adj_info = tf.assign(adj_info, minibatch.test_adj)
                for epoch in range(FLAGS.epochs):
                    minibatch.shuffle()

                    iter = 0
                    print('Epoch: %04d' % (epoch + 1))
                    epoch_val_costs.append(0)
                    while not minibatch.end():
                        # Construct feed dictionary
                        feed_dict = minibatch.next_minibatch_feed_dict()
                        feed_dict.update(
                            {placeholders['dropout']: FLAGS.dropout})

                        t = time.time()
                        # Training step
                        outs = sess.run([
                            merged, model.opt_op, model.loss, model.ranks,
                            model.aff_all, model.mrr, model.outputs1
                        ],
                                        feed_dict=feed_dict)
                        train_cost = outs[2]
                        train_mrr = outs[5]
                        if train_shadow_mrr is None:
                            train_shadow_mrr = train_mrr  #
                        else:
                            train_shadow_mrr -= (1 - 0.99) * (
                                train_shadow_mrr - train_mrr)

                        if iter % FLAGS.validate_iter == 0:
                            # Validation
                            sess.run(val_adj_info.op)
                            val_cost, ranks, val_mrr, duration = evaluate(
                                sess,
                                model,
                                minibatch,
                                size=FLAGS.validate_batch_size)
                            sess.run(train_adj_info.op)
                            epoch_val_costs[-1] += val_cost
                        if shadow_mrr is None:
                            shadow_mrr = val_mrr
                        else:
                            shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

                        if total_steps % FLAGS.print_every == 0:
                            summary_writer.add_summary(outs[0], total_steps)

                        # Print results
                        avg_time = (avg_time * total_steps + time.time() -
                                    t) / (total_steps + 1)

                        if total_steps % FLAGS.print_every == 0:
                            print(
                                "Iter:",
                                '%04d' % iter,
                                "train_loss=",
                                "{:.5f}".format(train_cost),
                                "train_mrr=",
                                "{:.5f}".format(train_mrr),
                                "train_mrr_ema=",
                                "{:.5f}".format(
                                    train_shadow_mrr
                                ),  # exponential moving average
                                "val_loss=",
                                "{:.5f}".format(val_cost),
                                "val_mrr=",
                                "{:.5f}".format(val_mrr),
                                "val_mrr_ema=",
                                "{:.5f}".format(
                                    shadow_mrr),  # exponential moving average
                                "time=",
                                "{:.5f}".format(avg_time))

                        iter += 1
                        total_steps += 1

                        if total_steps > FLAGS.max_total_steps:
                            break

                    if total_steps > FLAGS.max_total_steps:
                        break

                print("Optimization Finished!")
                sess.run(val_adj_info.op)
                # val_losess,val_mrr, duration = incremental_evaluate(sess, model, minibatch,
                #                                                  FLAGS.batch_size)
                # print("Full validation stats:",
                #       "loss=", "{:.5f}".format(val_losess),
                #       "val_mrr=", "{:.5f}".format(val_mrr),
                #       "time=", "{:.5f}".format(duration))

                if FLAGS.save_embeddings:
                    sess.run(val_adj_info.op)

                    embeds = save_val_embeddings(sess, model, minibatch,
                                                 FLAGS.validate_batch_size,
                                                 log_dir(action))

                    if FLAGS.model == "n2v":
                        # stopping the gradient for the already trained nodes
                        train_ids = tf.constant(
                            [[id_map[n]] for n in G.nodes_iter()
                             if not G.node[n]['val'] and not G.node[n]['test']
                             ],
                            dtype=tf.int32)
                        test_ids = tf.constant(
                            [[id_map[n]] for n in G.nodes_iter()
                             if G.node[n]['val'] or G.node[n]['test']],
                            dtype=tf.int32)
                        update_nodes = tf.nn.embedding_lookup(
                            model.context_embeds, tf.squeeze(test_ids))
                        no_update_nodes = tf.nn.embedding_lookup(
                            model.context_embeds, tf.squeeze(train_ids))
                        update_nodes = tf.scatter_nd(
                            test_ids, update_nodes,
                            tf.shape(model.context_embeds))
                        no_update_nodes = tf.stop_gradient(
                            tf.scatter_nd(train_ids, no_update_nodes,
                                          tf.shape(model.context_embeds)))
                        model.context_embeds = update_nodes + no_update_nodes
                        sess.run(model.context_embeds)

                        # run random walks
                        from graphsage.utils import run_random_walks
                        nodes = [
                            n for n in G.nodes_iter()
                            if G.node[n]["val"] or G.node[n]["test"]
                        ]
                        start_time = time.time()
                        pairs = run_random_walks(G, nodes, num_walks=50)
                        walk_time = time.time() - start_time

                        test_minibatch = EdgeMinibatchIterator(
                            G,
                            id_map,
                            placeholders,
                            batch_size=FLAGS.batch_size,
                            max_degree=FLAGS.max_degree,
                            num_neg_samples=FLAGS.neg_sample_size,
                            context_pairs=pairs,
                            n2v_retrain=True,
                            fixed_n2v=True)

                        start_time = time.time()
                        print("Doing test training for n2v.")
                        test_steps = 0
                        for epoch in range(FLAGS.n2v_test_epochs):
                            test_minibatch.shuffle()
                            while not test_minibatch.end():
                                feed_dict = test_minibatch.next_minibatch_feed_dict(
                                )
                                feed_dict.update(
                                    {placeholders['dropout']: FLAGS.dropout})
                                outs = sess.run([
                                    model.opt_op, model.loss, model.ranks,
                                    model.aff_all, model.mrr, model.outputs1
                                ],
                                                feed_dict=feed_dict)
                                if test_steps % FLAGS.print_every == 0:
                                    print("Iter:", '%04d' % test_steps,
                                          "train_loss=",
                                          "{:.5f}".format(outs[1]),
                                          "train_mrr=",
                                          "{:.5f}".format(outs[-2]))
                                test_steps += 1
                        train_time = time.time() - start_time
                        save_val_embeddings(sess,
                                            model,
                                            minibatch,
                                            FLAGS.validate_batch_size,
                                            log_dir(action),
                                            mod="-test")
                        print("Total time: ", train_time + walk_time)
                        print("Walk time: ", walk_time)
                        print("Train time: ", train_time)
    setting = "val"
    tf.reset_default_graph()
    labels = train_data[4]
    train_ids = [
        n for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']
    ]
    test_ids = [n for n in G.nodes() if G.node[n][setting]]
    train_labels = np.array([labels[i] for i in train_ids])
    if train_labels.ndim == 1:
        train_labels = np.expand_dims(train_labels, 1)
    test_labels = np.array([labels[i] for i in test_ids])
    id_map = {}
    with open(log_dir(action) + "/val.txt") as fp:
        for i, line in enumerate(fp):
            id_map[line.strip()] = i
    train_embeds = embeds[[id_map[str(id)] for id in train_ids]]
    test_embeds = embeds[[id_map[str(id)] for id in test_ids]]

    print("Running regression..")
    regress_fun(train_embeds, train_labels, test_embeds, test_labels)
    # Use the F1 score in place of accuracy; no exponential moving average is applied here.
    return get_rewards(val_mrr), val_mrr
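regress_fun and get_rewards are helpers defined elsewhere in this project and are not shown in the snippet. As a hedged illustration only, a minimal regress_fun along the lines of the F1-based evaluation mentioned in the comment above could look like this (scikit-learn assumed; signature taken from the call above):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

def regress_fun(train_embeds, train_labels, test_embeds, test_labels):
    # Fit a logistic-regression classifier on the learned embeddings and
    # report micro/macro F1 on the held-out nodes.
    clf = LogisticRegression(max_iter=1000)
    clf.fit(train_embeds, np.ravel(train_labels))
    pred = clf.predict(test_embeds)
    print("Test F1 micro:", f1_score(np.ravel(test_labels), pred, average="micro"))
    print("Test F1 macro:", f1_score(np.ravel(test_labels), pred, average="macro"))
    return pred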
Ejemplo n.º 8
0
def train(train_data, test_data=None, sampler_name='Uniform'):
    
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G, 
            id_map,
            placeholders, batch_size=FLAGS.batch_size,
            max_degree=FLAGS.max_degree, 
            num_neg_samples=FLAGS.neg_sample_size,
            context_pairs = context_pairs)
    
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
    adj_shape = adj_info.get_shape().as_list()

    if FLAGS.model == 'mean_concat':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     concat=True,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    
    elif FLAGS.model == 'mean_add':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     concat=False,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'gcn':

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

           
        # Create model
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="gcn",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     concat=False,
                                     logging=True)

    elif FLAGS.model == 'graphsage_seq':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     identity_dim = FLAGS.identity_dim,
                                     aggregator_type="seq",
                                     model_size=FLAGS.model_size,
                                     logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)
    
            
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="maxpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)
            
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="meanpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(placeholders, features.shape[0],
                                       minibatch.deg,
                                       #2x because graphsage uses concat
                                       nodevec_dim=2*FLAGS.dim_1,
                                       lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    
    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(sampler_name), sess.graph)
    
    
    # Save model
    saver = tf.train.Saver()
    model_path =  './model/unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    
    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
    
    
    # Restore params of ML sampler model
    if sampler_name == 'ML' or sampler_name == 'FastML':
        sampler_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="MLsampler")
        #pdb.set_trace() 
        saver_sampler = tf.train.Saver(var_list=sampler_vars)
        sampler_model_path = './model/MLsampler-unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix
        sampler_model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)

        saver_sampler.restore(sess, sampler_model_path + 'model.ckpt')

   
    # Train model
    
    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
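    # Assign ops used to swap the adjacency table backing the neighbor sampler
    # between the training neighborhoods and the validation/test neighborhoods.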

    val_cost_ = []
    val_mrr_ = []
    shadow_mrr_ = []
    duration_ = []
    
    ln_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.float32)
    lnc_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.int32)
    
    ln_acc = ln_acc.tolil()
    lnc_acc = lnc_acc.tolil()
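    # ln_acc / lnc_acc accumulate, per adjacency entry, the summed loss values
    # and the number of times each entry was hit; both are written out as
    # sparse .npz files after training.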

    for epoch in range(FLAGS.epochs): 
        minibatch.shuffle() 

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()

            # Skip the final incomplete minibatch of an epoch.
            if feed_dict[placeholders['batch_size']] != FLAGS.batch_size:
                break

            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.ranks, model.aff_all, 
                    model.mrr, model.outputs1, model.loss_node, model.loss_node_count], feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr#
            else:
                train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)
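            # Equivalent to train_shadow_mrr = 0.99 * train_shadow_mrr + 0.01 * train_mrr,
            # i.e. an exponential moving average of the training MRR.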

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration  = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)
            
            val_cost_.append(val_cost)
            val_mrr_.append(val_mrr)
            shadow_mrr_.append(shadow_mrr)
            duration_.append(duration)


            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)
    
            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print("Iter:", '%04d' % iter, 
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_mrr=", "{:.5f}".format(train_mrr), 
                      "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_mrr=", "{:.5f}".format(val_mrr), 
                      "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
                      "time=", "{:.5f}".format(avg_time))

            
            ln = outs[7].values
            ln_idx = outs[7].indices
            ln_acc[ln_idx[:,0], ln_idx[:,1]] += ln

            lnc = outs[8].values
            lnc_idx = outs[8].indices
            lnc_acc[lnc_idx[:,0], lnc_idx[:,1]] += lnc
                
                
            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
                break


    print("Validation per epoch in training")
    for ep in range(FLAGS.epochs):
        print("Epoch: %04d"%ep, " val_cost={:.5f}".format(val_cost_[ep]), " val_mrr={:.5f}".format(val_mrr_[ep]), " val_mrr_ema={:.5f}".format(shadow_mrr_[ep]), " duration={:.5f}".format(duration_[ep]))
 
    print("Optimization Finished!")
    
    # Save model
    save_path = saver.save(sess, model_path+'model.ckpt')
    print ('model is saved at %s'%save_path)


    # Save loss node and count
    loss_node_path = './loss_node/unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    loss_node_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)
    if not os.path.exists(loss_node_path):
        os.makedirs(loss_node_path)

    sparse.save_npz(loss_node_path + 'loss_node.npz', sparse.csr_matrix(ln_acc))
    sparse.save_npz(loss_node_path + 'loss_node_count.npz', sparse.csr_matrix(lnc_acc))
    print('loss and count per node are saved at %s' % loss_node_path)
    
    
    
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(sampler_name))

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if not G.node[n]['val'] and not G.node[n]['test']],
                    dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if G.node[n]['val'] or G.node[n]['test']], 
                    dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes, tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(tf.scatter_nd(train_ids, no_update_nodes, tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
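            # The context embedding table is rebuilt so that rows for train nodes
            # pass through stop_gradient (frozen) while rows for val/test nodes
            # remain trainable during the retraining loop below.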
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [n for n in G.nodes_iter() if G.node[n]["val"] or G.node[n]["test"]]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(G, 
                id_map,
                placeholders, batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree, 
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs = pairs,
                n2v_retrain=True,
                fixed_n2v=True)
            
            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([model.opt_op, model.loss, model.ranks, model.aff_all, 
                        model.mrr, model.outputs1], feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, 
                              "train_loss=", "{:.5f}".format(outs[1]),
                              "train_mrr=", "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(sampler_name), mod="-test")
            print("Total time: ", train_time+walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Ejemplo n.º 9
0
def train(train_data,
          log_dir,
          Theta=None,
          fixed_neigh_weights=None,
          test_data=None,
          neg_sample_weights=None):
    G = train_data[0]
    Gneg = Graph_complement(G)
    features = train_data[1]
    id_map = train_data[2]
    Adj_mat = Edges_to_Adjacency_mat(G.edges(), len(G.nodes()))
    #     print('A in unsup: ', Adj_mat)
    #     negAdj_mat = Edges_to_Adjacency_mat(Gneg.edges(), len(Gneg.nodes()))
    FLAGS.batch_size = batch_size_def(len(G.edges()))
    FLAGS.negbatch_size = batch_size_def(len(Gneg.edges()))
    FLAGS.samples_1 = min(25, len(G.nodes()))
    FLAGS.samples_2 = min(10, len(G.nodes()))
    FLAGS.negsamples_1 = min(25, len(Gneg.nodes()))
    FLAGS.negsamples_2 = min(10, len(Gneg.nodes()))
    FLAGS.neg_sample_size = FLAGS.batch_size
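    # Batch and neighborhood-sample sizes are derived from the (typically small)
    # graph here; batch_size_def, Graph_complement and Edges_to_Adjacency_mat are
    # helpers defined elsewhere in this project (see the sketch after this example).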

    if (len(G.nodes()) <= small_big_threshold):
        FLAGS.model_size = 'small'
    else:
        FLAGS.model_size = 'big'

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    #print('placeholders: ', placeholders)
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    aggbatch_size = len(minibatch.agg_batch_Z1)
    negaggbatch_size = len(minibatch.agg_batch_Z3)
    negminibatch = EdgeMinibatchIterator(Gneg,
                                         id_map,
                                         placeholders,
                                         batch_size=FLAGS.negbatch_size,
                                         max_degree=FLAGS.max_degree,
                                         num_neg_samples=FLAGS.neg_sample_size,
                                         context_pairs=context_pairs)

    #adj_info = tf.Variable(placeholders['adj_info_ph'], trainable=False, name="adj_info")
    adj_info_ph = tf.placeholder(tf.int32,
                                 shape=minibatch.adj.shape,
                                 name="adj_info_ph")
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    #summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    #minibatch.adj = minibatch.adj.astype(np.int32)
    #print('minibatch.adj.shape: %s, dtype: %s' % (minibatch.adj.shape, np.ndarray.dtype(minibatch.adj)))
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    print_flag = False
    if (Theta is None):
        print_flag = True

    for epoch in range(FLAGS.epochs):

        minibatch.shuffle()
        negminibatch.shuffle()
        iter = 0
        #         print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        if (FLAGS.model_size == 'big'):
            whichbatch = minibatch
        elif (len(G.edges()) > len(Gneg.edges())):
            whichbatch = minibatch
            opwhichbatch = negminibatch
        else:
            whichbatch = negminibatch
            opwhichbatch = minibatch

        while (not whichbatch.end()):
            if (FLAGS.model_size == 'small' and opwhichbatch.end()):
                opwhichbatch.shuffle()
            # Construct feed dictionary

            feed_dict = minibatch.next_minibatch_feed_dict()
            negfeed_dict = negminibatch.next_minibatch_feed_dict()

            if (True):
                feed_dict.update({
                    placeholders['negbatch1']:
                    negfeed_dict[placeholders['batch1']]
                })
                feed_dict.update({
                    placeholders['negbatch2']:
                    negfeed_dict[placeholders['batch2']]
                })
                feed_dict.update({
                    placeholders['negbatch_size']:
                    negfeed_dict[placeholders['batch_size']]
                })
            else:
                batch1 = feed_dict[placeholders['batch1']]
                feed_dict.update({placeholders['negbatch1']: batch1})
                feed_dict.update({
                    placeholders['negbatch2']:
                    negminibatch.feed_dict_negbatch(batch1)
                })
                feed_dict.update({
                    placeholders['negbatch_size']:
                    negfeed_dict[placeholders['batch_size']]
                })

            if (Theta is not None):
                break
            if (total_steps == 0):
                print_summary(model, feed_dict, sess, Adj_mat, 'first',
                              print_flag)

            t = time.time()
            # Training step
            outs = sess.run(
                [
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1, model.outputs2,
                    model.negoutputs1, model.negoutputs2
                ],
                feed_dict=feed_dict
            )  #, model.current_similarity   , model.learned_vars['theta_1']  # , model.aggregation
            train_cost = outs[2]
            train_mrr = outs[5]

            #             Z = outs[8]
            #             outs = np.concatenate((outs[6],outs[7]), axis=0)
            #             indices = np.concatenate((feed_dict[placeholders['batch1']],feed_dict[placeholders['batch2']]))
            #             Z = np.zeros((len(G.nodes()),FLAGS.dim_2*2))
            #             for node in G.nodes():
            #                 Z[node,:] = outs[int(np.argwhere(indices==node)[0]),:]#[0:len(G.nodes())] # 8
            #             print('Z shape: ', Z.shape)

            #             if train_shadow_mrr is None:
            #                 train_shadow_mrr = train_mrr #
            #             else:
            #                 train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)
            #
            #             if iter % FLAGS.validate_iter == 0:
            #                 # Validation
            #                 sess.run(val_adj_info.op)
            #                 val_cost, ranks, val_mrr, duration = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size, negminibatch_iter=negminibatch, placeholders=placeholders)
            #                 sess.run(train_adj_info.op)
            #                 epoch_val_costs[-1] += val_cost
            #
            #             if shadow_mrr is None:
            #                 shadow_mrr = val_mrr
            #             else:
            #                 shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)

            #             if total_steps % FLAGS.print_every == 0:
            #                 summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print('loss: ', outs[2])
#                 print("Iter:", '%04d' % iter,
#                       "train_loss=", "{:.5f}".format(train_cost),
#                       "train_mrr=", "{:.5f}".format(train_mrr),
#                       "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
#                       "val_loss=", "{:.5f}".format(val_cost),
#                       "val_mrr=", "{:.5f}".format(val_mrr),
#                       "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
#                       "time=", "{:.5f}".format(avg_time))

#             similarity_weights = outs[7]
#             [new_adj_info, new_batch_edges] = sess.run([model.adj_info, \
#                                                                 model.new_batch_edges], \
#                                                                   feed_dict=feed_dict)
#             minibatch.graph_update(new_adj_info, new_batch_edges)

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if (Theta is not None):
            break
        if total_steps > FLAGS.max_total_steps:
            break


#     print("SGD Optimization Finished!")

#     feed_dict = dict()
#     feed_dict.update({placeholders['batch_size'] : len(G.nodes())})

#     minibatch = EdgeMinibatchIterator(G,
#             id_map,
#             placeholders, batch_size = len(G.edges()),
#             max_degree = FLAGS.max_degree,
#             num_neg_samples = FLAGS.neg_sample_size,
#             context_pairs = context_pairs)
#     feed_dict = minibatch.next_minibatch_feed_dict()
#     _, Z = sess.run([merged, model.aggregation], feed_dict=feed_dict) #, model.concat   ,     aggregator_cls.vars

#     Z_tilde = np.repeat(Z, [len(G.nodes())], axis=0)
#     Z_tilde_tilde = np.tile(Z, (len(G.nodes()),1))
#     final_theta_1 = Quadratic_SDP_solver (Z_tilde, Z_tilde_tilde, FLAGS.dim_1*2, FLAGS.dim_2*2)
# #     Z_centralized = Z - np.mean(Z, axis=0)
# #     final_adj_matrix = np.abs(np.matmul(Z_centralized, np.matmul(final_theta_1, np.transpose(Z_centralized))))
#     final_adj_matrix = np.matmul(Z, np.matmul(final_theta_1, np.transpose(Z)))
#     feed_dict = minibatch.next_minibatch_feed_dict()
#     feed_dict.update({placeholders['dropout']: FLAGS.dropout})
#     feed_dict.update({placeholders['all_nodes']: all_nodes})

#     FLAGS.batch_size = len(G.edges())
#     FLAGS.negbatch_size = len(Gneg.edges())
#     FLAGS.samples_1 = len(G.nodes())
#     FLAGS.samples_2 = len(G.nodes())
#     FLAGS.negsamples_1 = len(Gneg.nodes())
#     FLAGS.negsamples_2 = len(Gneg.nodes())
#     feed_dict = minibatch.batch_feed_dict(G.edges())
#     negfeed_dict = negminibatch.batch_feed_dict(Gneg.edges())
#     feed_dict.update({placeholders['negbatch1']: negfeed_dict[placeholders['batch1']]})
#     feed_dict.update({placeholders['negbatch2']: negfeed_dict[placeholders['batch2']]})
#     feed_dict.update({placeholders['negbatch_size']: negfeed_dict[placeholders['batch_size']]})
#     if(outs is not None):
#         print('loss: ', outs[2])

    final_adj_matrix, final_theta_1, Z, loss, U = print_summary(
        model, feed_dict, sess, Adj_mat, 'last', print_flag)

    #print('Z shape: ', Z.shape)

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
    #del adj_info_ph, adj_info,placeholders
    return final_adj_matrix, G, final_theta_1, Z, loss, U  #, learned_vars
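Graph_complement and Edges_to_Adjacency_mat are used above but not defined in this snippet. A minimal sketch of plausible implementations, assuming the complement graph supplies the non-edges used as negative examples and that nodes carry integer ids 0..n-1:

import networkx as nx
import numpy as np

def Graph_complement(G):
    # Hypothetical helper: complement graph whose edges are the non-edges of G,
    # used above as the source of negative examples (assumption).
    return nx.complement(G)

def Edges_to_Adjacency_mat(edges, num_nodes):
    # Hypothetical helper: dense symmetric 0/1 adjacency matrix from an edge
    # list, assuming integer node ids in the range [0, num_nodes).
    A = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    for u, v in edges:
        A[u, v] = 1.0
        A[v, u] = 1.0
    return A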
Ejemplo n.º 10
0
def train(train_data, G_local,test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G, G_local,
            id_map,
            placeholders, batch_size=FLAGS.batch_size,
            max_degree=FLAGS.max_degree, 
            num_neg_samples=FLAGS.neg_sample_size,
            context_pairs = context_pairs)
    # Record the train/validation/test edge counts for this worker.
    edge_file = open(log_dir() + "/" + str(FLAGS.graph_id) + "_edge_detail_" + FLAGS.train_worker + ".txt", "w")
    edge_file.write("train-edges " + str(len(minibatch.train_edges)) + "\n")
    edge_file.write("valid-edges " + str(len(minibatch.val_edges)) + "\n")
    edge_file.write("test-edges " + str(len(minibatch.test_edges)) + "\n")
    edge_file.close()

    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="gcn",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     concat=False,
                                     logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     identity_dim = FLAGS.identity_dim,
                                     aggregator_type="seq",
                                     model_size=FLAGS.model_size,
                                     logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="maxpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="meanpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(placeholders, features.shape[0],
                                       minibatch.deg,
                                       #2x because graphsage uses concat
                                       nodevec_dim=2*FLAGS.dim_1,
                                       lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    
    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    saver = tf.train.Saver()
     
    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj, model.features_placeholder: features})

    # Train model
    
    train_shadow_mrr = None
    shadow_mrr = None
    test_shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs): 
        minibatch.shuffle() 

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.ranks, model.aff_all, 
                    model.mrr, model.outputs1], feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr#
            else:
                train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration  = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)

            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print("Epoch: ,%04d" % (epoch + 1),
                      "Iter:", '%04d' % iter,
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_mrr=", "{:.5f}".format(train_mrr), 
                      "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_mrr=", "{:.5f}".format(val_mrr), 
                      "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
                      "time=", "{:.5f}".format(avg_time))
                with open(log_dir() + "/"+str(FLAGS.graph_id)+"_validate_stats_"+FLAGS.train_worker+".txt", "a+") as fp:
                    fp.write("train_loss={:.5f} train_mrr={:.5f} train_mrr_ema={:.5f} val_loss={:.5f} val_mrr={:.5f} val_mrr_ema={:.5f} time={:.5f}".
                             format(train_cost, train_mrr, train_shadow_mrr,val_cost,val_mrr,shadow_mrr, avg_time))
                    fp.write("\n")


            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
                break

    # Record the time at which training finished for this worker.
    train_end_file = open(log_dir() + "/" + str(FLAGS.graph_id) + "_train_end_time_" + FLAGS.train_worker + ".txt", "w")
    trainendDT = datetime.datetime.now()
    print(str(trainendDT))
    train_end_file.write(str(trainendDT))
    train_end_file.close()

    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir())
        save_central_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir())

        test_cost, test_ranks, test_mrr, test_duration  = evaluate_test(sess, model, minibatch, size=FLAGS.validate_batch_size)
        if test_shadow_mrr is None:
            test_shadow_mrr = test_mrr
        else:
            test_shadow_mrr -= (1-0.99) * (test_shadow_mrr - test_mrr)

        print("Full Test stats:",
              "test_loss=", "{:.5f}".format(test_cost),
              "test_mrr=", "{:.5f}".format(test_mrr),
              "test_mrr_ema=", "{:.5f}".format(test_shadow_mrr),
              "time=", "{:.5f}".format(test_duration))
        # Record the time at which evaluation finished for this worker.
        end_file = open(log_dir() + "/" + str(FLAGS.graph_id) + "_end_time_" + FLAGS.train_worker + ".txt", "w")
        endDT = datetime.datetime.now()
        print(str(endDT))
        end_file.write(str(endDT))
        end_file.close()

        with open(log_dir() + "/" +str(FLAGS.graph_id)+"_test_stats_"+FLAGS.train_worker+".txt", "w") as fp:
            fp.write("test_loss={:.5f} test_mrr={:.5f} test_mrr_ema={:.5f} time={:.5f}".
                     format(test_cost, test_mrr, test_shadow_mrr, test_duration))

        test_edge_set = [e for e in G.edges() if G[e[0]][e[1]]['testing']]

        with open(log_dir() + "/"+str(FLAGS.graph_id)+"_TEST_EDGE_SET_"+FLAGS.train_worker+".txt", "a+") as fp:
            for e in test_edge_set:
                fp.write(str(e[0])+" "+str(e[1]))
                fp.write("\n")

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if not G.node[n]['val'] and not G.node[n]['test']],
                    dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if G.node[n]['val'] or G.node[n]['test']], 
                    dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes, tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(tf.scatter_nd(train_ids, no_update_nodes, tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [n for n in G.nodes_iter() if G.node[n]["val"] or G.node[n]["test"]]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(G, G_local,
                id_map,
                placeholders, batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree, 
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs = pairs,
                n2v_retrain=True,
                fixed_n2v=True)
            
            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([model.opt_op, model.loss, model.ranks, model.aff_all, 
                        model.mrr, model.outputs1], feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, 
                              "train_loss=", "{:.5f}".format(outs[1]),
                              "train_mrr=", "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(), mod="-test")
            print("Total time: ", train_time+walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    prob_matrix = train_data[5]

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      prob_matrix,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
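    # In this variant the EdgeMinibatchIterator also receives prob_matrix
    # (train_data[5]), presumably a sampling-probability matrix consulted when
    # drawing edges/neighbors; its exact use is inside the iterator.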
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
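        # note: layer_infos is only defined in the GraphSAGE branches above, yet it
        # is passed to next_minibatch_feed_dict() in the training loop below; as
        # written, selecting 'n2v' would raise a NameError at that point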
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})
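    # adj_info_ph is only used here: the initializer feed supplies the initial
    # value of the adj_info variable (the sampled training adjacency)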

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
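    # val_adj_info swaps in the full adjacency (including val/test edges) before
    # embeddings are saved; note that train_adj_info.op is never run afterwards in
    # this function, so later epochs keep sampling from that full adjacency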

    min_train_loss = 0.0
    fp = open(log_dir() + 'epoch_train_loss.txt', 'w')
    loss_list = []
    epoch_list = []

    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        epoch_train_loss = 0.0

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict(layer_infos)
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]

            epoch_train_loss += train_cost

            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr
            else:
                # exponential moving average of the training MRR (decay 0.99)
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

        epoch_train_loss = epoch_train_loss / iter

        print('epoch: ' + str(epoch) + '\t' + 'train_loss: ' +
              str(epoch_train_loss))
        fp.write('epoch: ' + str(epoch) + '\t' + 'train_loss: ' +
                 str(epoch_train_loss) + '\n')
        loss_list.append(epoch_train_loss)
        epoch_list.append(epoch)

        # keep the embeddings from the epoch with the lowest average training loss
        if epoch == 0 or epoch_train_loss < min_train_loss:
            min_train_loss = epoch_train_loss
            sess.run(val_adj_info.op)
            node_embeddings, node_list = save_val_embeddings(
                sess, model, minibatch, FLAGS.validate_batch_size, log_dir(),
                layer_infos)

    print("Optimization Finished!")

    fp.close()

    np.save(log_dir() + "val.npy", node_embeddings)
    with open(log_dir() + "val.txt", "w") as fp:
        fp.write("\n".join(map(str, node_list)))

    plt.plot(epoch_list, loss_list)
    plt.savefig(log_dir() + 'loss_trend.png')
    plt.clf()

    ## scatter plot of the node embeddings
    ## t-SNE is left disabled here, so the first two raw embedding dimensions
    ## are plotted directly instead of a t-SNE projection
    #tsne_feat = TSNE(n_components=2).fit_transform(node_embeddings)
    tsne_feat = node_embeddings

    ## plotting the embedding scatter
    tsne_df = pd.DataFrame(columns=['tsne0', 'tsne1'])
    tsne_df['tsne0'] = tsne_feat[:, 0]
    tsne_df['tsne1'] = tsne_feat[:, 1]
    plt.figure(figsize=(16, 10))
    sns.scatterplot(x="tsne0",
                    y="tsne1",
                    hue=node_list,
                    palette=sns.color_palette("hls", len(node_list)),
                    data=tsne_df,
                    legend=False,
                    alpha=0.3)
    for i, txt in enumerate(node_list):
        plt.annotate(txt, (tsne_df['tsne0'][i], tsne_df['tsne1'][i]))
    plt.savefig(log_dir() + 't-SNE_plot_node_embeddings.png')
    plt.show()

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
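            # the context embedding table is rebuilt as the sum of two scatters:
            # rows of val/test nodes stay trainable, while rows of already-trained
            # nodes are wrapped in tf.stop_gradient so retraining cannot move them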
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)
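            # n2v_retrain / fixed_n2v presumably restrict this iterator to the
            # fresh walk pairs touching the held-out (val/test) nodes; note that,
            # unlike the training iterator above, no prob_matrix is passed here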

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict(
                        layer_infos)
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)