def time_for_unsupervised_sampling(G, id_map, walks, num_classes):
    tf.reset_default_graph()
    placeholders = {
        'labels':
        tf.placeholder(tf.float32, shape=(None, num_classes), name='labels'),
        'batch1':
        tf.placeholder(tf.int32, shape=(None, ), name='batch1'),
        'batch2':
        tf.placeholder(tf.int32, shape=(None, ), name='batch2'),
        'dropout':
        tf.placeholder_with_default(0., shape=(), name='dropout'),
        'batch_size':
        tf.placeholder(tf.int32, name='batch_size'),
    }
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      walks,
                                      batch_size=512,
                                      max_degree=100)
    label = tf.cast(minibatch.placeholders["batch2"], dtype=tf.int64)
    labels = tf.reshape(label, [placeholders['batch_size'], 1])
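    # Draw fixed-unigram negative samples: 20 negatives per batch, sampled in
    # proportion to node degree raised to the 0.75 power (the usual GraphSAGE
    # negative-sampling distribution for the unsupervised loss).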
    neg_samples, _, _ = (tf.nn.fixed_unigram_candidate_sampler(
        true_classes=labels,
        num_true=1,
        num_sampled=20,
        unique=False,
        range_max=len(minibatch.deg),
        distortion=0.75,
        unigrams=minibatch.deg.tolist()))
    pruned_adj_matrix = tf.constant(minibatch.adj, dtype=tf.int32)
    sampler = UniformNeighborSampler(pruned_adj_matrix)

    def get_two_hop_sampled(input_t):
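        # Two-hop neighborhood expansion: sample 25 first-hop neighbors per
        # input node, flatten, then sample 10 second-hop neighbors for each of
        # those, yielding batch_size * 25 * 10 sampled node ids.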
        sample1 = sampler((input_t, 25))
        reshaped_sample1 = tf.reshape(sample1, [
            tf.shape(sample1)[0] * 25,
        ])
        sample2 = sampler((reshaped_sample1, 10))
        return sample2

    source = get_two_hop_sampled(minibatch.placeholders["batch1"])
    target = get_two_hop_sampled(minibatch.placeholders["batch2"])
    neg_samples = get_two_hop_sampled(neg_samples)
    minibatch.shuffle()
    sess = tf.Session()
    start_time = time.time()
    i = 0
    max_steps = 1000
    while not minibatch.end():
        feed_dict = minibatch.next_minibatch_feed_dict()
        sess.run([source, target, neg_samples], feed_dict)
        i += 1
        if i >= max_steps:
            break
    print("Total number of steps run {}".format(i))
    end_time = time.time()
    print("Sampling time with negative {}".format(end_time - start_time))
    sess.close()
    add_to_dict("UNSSAMPLE", (end_time - start_time))
Example 2
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # default to saving all variables (added by wy)
    saver = tf.train.Saver()

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model
    if FLAGS.isTrain:
        print("== Training Model ==")
        train_shadow_mrr = None
        shadow_mrr = None

        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        train_adj_info = tf.assign(adj_info, minibatch.adj)
        val_adj_info = tf.assign(adj_info, minibatch.test_adj)
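        # Running these assign ops swaps adj_info between the training and the
        # held-out adjacency lists, so the same sampling graph serves both phases.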
        for epoch in range(FLAGS.epochs):
            minibatch.shuffle()

            iter = 0
            print('Epoch: %04d' % (epoch + 1))
            epoch_val_costs.append(0)
            while not minibatch.end():
                # Construct feed dictionary
                feed_dict = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})

                t = time.time()
                # Training step
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1
                ],
                                feed_dict=feed_dict)
                train_cost = outs[2]
                train_mrr = outs[5]
                if train_shadow_mrr is None:
                    train_shadow_mrr = train_mrr  #
                else:
                    train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr -
                                                      train_mrr)

                if iter % FLAGS.validate_iter == 0:
                    # Validation
                    sess.run(val_adj_info.op)
                    val_cost, ranks, val_mrr, duration = evaluate(
                        sess, model, minibatch, size=FLAGS.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost
                if shadow_mrr is None:
                    shadow_mrr = val_mrr
                else:
                    shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

                if total_steps % FLAGS.print_every == 0:
                    summary_writer.add_summary(outs[0], total_steps)

                # Print results
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % FLAGS.print_every == 0:
                    print(
                        "Iter:",
                        '%04d' % iter,
                        "train_loss=",
                        "{:.5f}".format(train_cost),
                        "train_mrr=",
                        "{:.5f}".format(train_mrr),
                        "train_mrr_ema=",
                        "{:.5f}".format(
                            train_shadow_mrr),  # exponential moving average
                        "val_loss=",
                        "{:.5f}".format(val_cost),
                        "val_mrr=",
                        "{:.5f}".format(val_mrr),
                        "val_mrr_ema=",
                        "{:.5f}".format(
                            shadow_mrr),  # exponential moving average
                        "time=",
                        "{:.5f}".format(avg_time))

                iter += 1
                total_steps += 1

                if total_steps > FLAGS.max_total_steps:
                    break

            if total_steps > FLAGS.max_total_steps:
                break
        print("Optimization Finished!")
        # save model
        saver.save(sess, log_dir() + 'model.ckpt')
        print("save model succeed!")
    else:
        print("==Testing Model==")
        ckpt = tf.train.get_checkpoint_state(log_dir())
        print(log_dir())
        print(ckpt)
        print(ckpt.model_checkpoint_path)
        raw_input()
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass

    raw_input()
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Example 3
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.compat.v1.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.compat.v1.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize WandB experiment
    wandb.init(project='GraphSAGE_trial',
               entity='cvl-liu-01',
               save_code=True,
               tags=['unsupervised'])
    wandb.config.update(flags.FLAGS)

    # Initialize session
    sess = tf.compat.v1.Session(config=config)
    merged = tf.compat.v1.summary.merge_all()
    summary_writer = tf.compat.v1.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.compat.v1.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Init saver
    saver = tf.compat.v1.train.Saver(max_to_keep=8,
                                     keep_checkpoint_every_n_hours=1)

    # Train model
    train_shadow_mrr = None
    val_shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
    val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  #
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            # Validation
            if iter % FLAGS.validate_iter == 0:
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if val_shadow_mrr is None:
                val_shadow_mrr = val_mrr
            else:
                val_shadow_mrr -= (1 - 0.99) * (val_shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "[%03d/%03d]" % (epoch + 1, FLAGS.epochs),
                    "Iter:",
                    '[%05d/%05d]' % (iter, minibatch.num_training_batches()),
                    "train_loss =",
                    "{:.5f}".format(train_cost),
                    "train_mrr =",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema =",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss =",
                    "{:.5f}".format(val_cost),
                    "val_mrr =",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema =",
                    "{:.5f}".format(
                        val_shadow_mrr),  # exponential moving average
                    "time =",
                    "{:.5f}".format(avg_time))

            # W&B Logging
            if FLAGS.wandb_log and iter % FLAGS.wandb_log_iter == 0:
                wandb.log({'train_loss': train_cost, 'epoch': epoch})
                wandb.log({'train_mrr': train_mrr, 'epoch': epoch})
                wandb.log({'train_mrr_ema': train_shadow_mrr, 'epoch': epoch})
                wandb.log({'val_loss': val_cost, 'epoch': epoch})
                wandb.log({'val_mrr': val_mrr, 'epoch': epoch})
                wandb.log({'val_mrr_ema': val_shadow_mrr, 'epoch': epoch})
                wandb.log({'time': avg_time, 'epoch': epoch})

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                print('Max total steps reached!')
                break

        # Save embeddings
        if FLAGS.save_embeddings and epoch % FLAGS.save_embeddings_epoch == 0:
            save_val_embeddings(sess, model, minibatch,
                                FLAGS.validate_batch_size, log_dir())

            # Also report classifier metric on the embedding
            all_tr_res, all_ts_res = osm_classif_eval.evaluate(
                FLAGS.train_prefix, log_dir(), n_iter=FLAGS.classif_n_iter)
            if FLAGS.wandb_log:
                wandb.log(all_tr_res)
                wandb.log(all_ts_res)

        # Save Model checkpoints
        if FLAGS.save_checkpoints and epoch % FLAGS.save_checkpoints_epoch == 0:
            # saver.save(sess, log_dir() + 'model', global_step=1000)
            print('Save model checkpoint:', wandb.run.dir, iter, total_steps,
                  epoch)
            saver.save(
                sess,
                os.path.join(wandb.run.dir,
                             "model-" + str(epoch + 1) + ".ckpt"))

        if total_steps > FLAGS.max_total_steps:
            print('Max total steps reached!')
            break

    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Example 4
def train(train_data, action, test_data=None, regress_fun=run_regression):
    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    # Initialize session
    # Define a separate scope/graph so it does not conflict with the Controller
    with tf.Session(config=config, graph=tf.Graph()) as sess:
        with sess.as_default():
            with sess.graph.as_default():
                # Set random seed
                seed = 123
                np.random.seed(seed)
                tf.set_random_seed(seed)

                G = train_data[0]
                features = train_data[1]
                id_map = train_data[2]

                if features is not None:
                    # pad with dummy zero vector
                    features = np.vstack(
                        [features, np.zeros((features.shape[1], ))])

                context_pairs = train_data[3] if FLAGS.random_context else None
                placeholders = construct_placeholders()
                minibatch = EdgeMinibatchIterator(
                    G,
                    id_map,
                    placeholders,
                    batch_size=FLAGS.batch_size,
                    max_degree=FLAGS.max_degree,
                    num_neg_samples=FLAGS.neg_sample_size,
                    context_pairs=context_pairs)
                adj_info_ph = tf.placeholder(tf.int32,
                                             shape=minibatch.adj.shape)
                adj_info = tf.Variable(adj_info_ph,
                                       trainable=False,
                                       name="adj_info")

                sampler = UniformNeighborSampler(adj_info)  # neighbor sampling via random reshuffling of each node's neighbor list
                state_nums = 2  # number of states defined by the Controller
                layers_num = len(action) // state_nums  # number of layers
                layer_infos = []
                # guides the construction of the final GNN
                layer_infos.append(
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1))
                layer_infos.append(
                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1))
                # unsupervised GraphSAGE variant used for NAS
                model = NASUnsupervisedGraphsage(
                    placeholders,
                    features,
                    adj_info,
                    minibatch.deg,
                    layer_infos=layer_infos,
                    action=action,
                    state_nums=state_nums,
                    model_size=FLAGS.model_size,
                    identity_dim=FLAGS.identity_dim,
                    logging=True)

                merged = tf.summary.merge_all()
                summary_writer = tf.summary.FileWriter(log_dir(action),
                                                       sess.graph)

                # Init variables
                sess.run(tf.global_variables_initializer(),
                         feed_dict={adj_info_ph: minibatch.adj})

                # Train model

                train_shadow_mrr = None
                shadow_mrr = None

                total_steps = 0
                avg_time = 0.0
                epoch_val_costs = []

                train_adj_info = tf.assign(adj_info, minibatch.adj)
                val_adj_info = tf.assign(adj_info, minibatch.test_adj)
                for epoch in range(FLAGS.epochs):
                    minibatch.shuffle()

                    iter = 0
                    print('Epoch: %04d' % (epoch + 1))
                    epoch_val_costs.append(0)
                    while not minibatch.end():
                        # Construct feed dictionary
                        feed_dict = minibatch.next_minibatch_feed_dict()
                        feed_dict.update(
                            {placeholders['dropout']: FLAGS.dropout})

                        t = time.time()
                        # Training step
                        outs = sess.run([
                            merged, model.opt_op, model.loss, model.ranks,
                            model.aff_all, model.mrr, model.outputs1
                        ],
                                        feed_dict=feed_dict)
                        train_cost = outs[2]
                        train_mrr = outs[5]
                        if train_shadow_mrr is None:
                            train_shadow_mrr = train_mrr  #
                        else:
                            train_shadow_mrr -= (1 - 0.99) * (
                                train_shadow_mrr - train_mrr)

                        if iter % FLAGS.validate_iter == 0:
                            # Validation
                            sess.run(val_adj_info.op)
                            val_cost, ranks, val_mrr, duration = evaluate(
                                sess,
                                model,
                                minibatch,
                                size=FLAGS.validate_batch_size)
                            sess.run(train_adj_info.op)
                            epoch_val_costs[-1] += val_cost
                        if shadow_mrr is None:
                            shadow_mrr = val_mrr
                        else:
                            shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

                        if total_steps % FLAGS.print_every == 0:
                            summary_writer.add_summary(outs[0], total_steps)

                        # Print results
                        avg_time = (avg_time * total_steps + time.time() -
                                    t) / (total_steps + 1)

                        if total_steps % FLAGS.print_every == 0:
                            print(
                                "Iter:",
                                '%04d' % iter,
                                "train_loss=",
                                "{:.5f}".format(train_cost),
                                "train_mrr=",
                                "{:.5f}".format(train_mrr),
                                "train_mrr_ema=",
                                "{:.5f}".format(
                                    train_shadow_mrr
                                ),  # exponential moving average
                                "val_loss=",
                                "{:.5f}".format(val_cost),
                                "val_mrr=",
                                "{:.5f}".format(val_mrr),
                                "val_mrr_ema=",
                                "{:.5f}".format(
                                    shadow_mrr),  # exponential moving average
                                "time=",
                                "{:.5f}".format(avg_time))

                        iter += 1
                        total_steps += 1

                        if total_steps > FLAGS.max_total_steps:
                            break

                    if total_steps > FLAGS.max_total_steps:
                        break

                print("Optimization Finished!")
                sess.run(val_adj_info.op)
                # val_losess,val_mrr, duration = incremental_evaluate(sess, model, minibatch,
                #                                                  FLAGS.batch_size)
                # print("Full validation stats:",
                #       "loss=", "{:.5f}".format(val_losess),
                #       "val_mrr=", "{:.5f}".format(val_mrr),
                #       "time=", "{:.5f}".format(duration))

                if FLAGS.save_embeddings:
                    sess.run(val_adj_info.op)

                    embeds = save_val_embeddings(sess, model, minibatch,
                                                 FLAGS.validate_batch_size,
                                                 log_dir(action))

                    if FLAGS.model == "n2v":
                        # stopping the gradient for the already trained nodes
                        train_ids = tf.constant(
                            [[id_map[n]] for n in G.nodes_iter()
                             if not G.node[n]['val'] and not G.node[n]['test']
                             ],
                            dtype=tf.int32)
                        test_ids = tf.constant(
                            [[id_map[n]] for n in G.nodes_iter()
                             if G.node[n]['val'] or G.node[n]['test']],
                            dtype=tf.int32)
                        update_nodes = tf.nn.embedding_lookup(
                            model.context_embeds, tf.squeeze(test_ids))
                        no_update_nodes = tf.nn.embedding_lookup(
                            model.context_embeds, tf.squeeze(train_ids))
                        update_nodes = tf.scatter_nd(
                            test_ids, update_nodes,
                            tf.shape(model.context_embeds))
                        no_update_nodes = tf.stop_gradient(
                            tf.scatter_nd(train_ids, no_update_nodes,
                                          tf.shape(model.context_embeds)))
                        model.context_embeds = update_nodes + no_update_nodes
                        sess.run(model.context_embeds)

                        # run random walks
                        from graphsage.utils import run_random_walks
                        nodes = [
                            n for n in G.nodes_iter()
                            if G.node[n]["val"] or G.node[n]["test"]
                        ]
                        start_time = time.time()
                        pairs = run_random_walks(G, nodes, num_walks=50)
                        walk_time = time.time() - start_time

                        test_minibatch = EdgeMinibatchIterator(
                            G,
                            id_map,
                            placeholders,
                            batch_size=FLAGS.batch_size,
                            max_degree=FLAGS.max_degree,
                            num_neg_samples=FLAGS.neg_sample_size,
                            context_pairs=pairs,
                            n2v_retrain=True,
                            fixed_n2v=True)

                        start_time = time.time()
                        print("Doing test training for n2v.")
                        test_steps = 0
                        for epoch in range(FLAGS.n2v_test_epochs):
                            test_minibatch.shuffle()
                            while not test_minibatch.end():
                                feed_dict = test_minibatch.next_minibatch_feed_dict(
                                )
                                feed_dict.update(
                                    {placeholders['dropout']: FLAGS.dropout})
                                outs = sess.run([
                                    model.opt_op, model.loss, model.ranks,
                                    model.aff_all, model.mrr, model.outputs1
                                ],
                                                feed_dict=feed_dict)
                                if test_steps % FLAGS.print_every == 0:
                                    print("Iter:", '%04d' % test_steps,
                                          "train_loss=",
                                          "{:.5f}".format(outs[1]),
                                          "train_mrr=",
                                          "{:.5f}".format(outs[-2]))
                                test_steps += 1
                        train_time = time.time() - start_time
                        save_val_embeddings(sess,
                                            model,
                                            minibatch,
                                            FLAGS.validate_batch_size,
                                            log_dir(action),
                                            mod="-test")
                        print("Total time: ", train_time + walk_time)
                        print("Walk time: ", walk_time)
                        print("Train time: ", train_time)
    setting = "val"
    tf.reset_default_graph()
    labels = train_data[4]
    train_ids = [
        n for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']
    ]
    test_ids = [n for n in G.nodes() if G.node[n][setting]]
    train_labels = np.array([labels[i] for i in train_ids])
    if train_labels.ndim == 1:
        train_labels = np.expand_dims(train_labels, 1)
    test_labels = np.array([labels[i] for i in test_ids])
    id_map = {}
    with open(log_dir(action) + "/val.txt") as fp:
        for i, line in enumerate(fp):
            id_map[line.strip()] = i
    train_embeds = embeds[[id_map[str(id)] for id in train_ids]]
    test_embeds = embeds[[id_map[str(id)] for id in test_ids]]

    print("Running regression..")
    regress_fun(train_embeds, train_labels, test_embeds, test_labels)
    # use the F1 score instead of accuracy; no exponential moving average is applied here
    return get_rewards(val_mrr), val_mrr
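
run_regression (the default regress_fun) and get_rewards are not defined in this excerpt; an illustrative sketch of the regression step only, assuming a simple scikit-learn linear classifier over the saved embeddings with micro-F1 as the reported metric:

from sklearn.linear_model import SGDClassifier
from sklearn.metrics import f1_score

def run_regression(train_embeds, train_labels, test_embeds, test_labels):
    # Fit a linear classifier on the node embeddings and report micro-F1
    # on the held-out split.
    clf = SGDClassifier(max_iter=1000)
    clf.fit(train_embeds, train_labels.ravel())
    preds = clf.predict(test_embeds)
    print("Test micro-F1:", f1_score(test_labels.ravel(), preds, average="micro"))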
Example 5
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    homologs = train_data[5]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G, 
            id_map,
            placeholders,
            batch_size=FLAGS.batch_size,
            homologs=homologs,
            max_degree=FLAGS.max_degree, 
            num_neg_samples=FLAGS.neg_sample_size,
            context_pairs = context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     homolog_loss=FLAGS.homolog_loss,
                                     homolog_importance=FLAGS.homolog_importance,
                                     logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="gcn",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     homolog_loss=FLAGS.homolog_loss,
                                     homolog_importance=FLAGS.homolog_importance,
                                     concat=False,
                                     logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     identity_dim = FLAGS.identity_dim,
                                     aggregator_type="seq",
                                     model_size=FLAGS.model_size,
                                     homolog_loss=FLAGS.homolog_loss,
                                     homolog_importance=FLAGS.homolog_importance,
                                     logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="maxpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     homolog_loss=FLAGS.homolog_loss,
                                     homolog_importance=FLAGS.homolog_importance,
                                     logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="meanpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     homolog_loss=FLAGS.homolog_loss,
                                     homolog_importance=FLAGS.homolog_importance,
                                     logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(placeholders, features.shape[0],
                                       minibatch.deg,
                                       #2x because graphsage uses concat
                                       nodevec_dim=2*FLAGS.dim_1,
                                       lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    
    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
     
    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
    
    # Train model
    
    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs): 
        minibatch.shuffle() 

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.ranks, model.aff_all, 
                    model.mrr, model.outputs1], feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            train_embs = outs[-1]
            # import pdb; pdb.set_trace()
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr#
            else:
                train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration  = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)
    
            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print("Iter:", '%010d' % iter,
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_mrr=", "{:.5f}".format(train_mrr), 
                      "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_mrr=", "{:.5f}".format(val_mrr), 
                      "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
                      "time=", "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps % 5000 == 0:
                sess.run(val_adj_info.op)
                save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir())
                sess.run(train_adj_info.op)
            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break
    
    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir())
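
save_val_embeddings() is used throughout these examples but never shown; a shortened sketch, assuming the iterator's incremental_embed_feed_dict(size, iter_num) interface from the GraphSAGE codebase and that numpy is already imported as np:

import os

def save_val_embeddings(sess, model, minibatch_iter, size, out_dir, mod=""):
    # Incrementally embed validation edges and save one vector per distinct
    # source node, plus the node ordering, under out_dir.
    val_embeddings, nodes, seen = [], [], set()
    finished, iter_num = False, 0
    while not finished:
        feed_dict_val, finished, edges = minibatch_iter.incremental_embed_feed_dict(
            size, iter_num)
        iter_num += 1
        outs_val = sess.run([model.loss, model.mrr, model.outputs1],
                            feed_dict=feed_dict_val)
        for i, edge in enumerate(edges):
            if edge[0] not in seen:
                val_embeddings.append(outs_val[-1][i, :])
                nodes.append(edge[0])
                seen.add(edge[0])
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    np.save(out_dir + "val" + mod + ".npy", np.vstack(val_embeddings))
    with open(out_dir + "val" + mod + ".txt", "w") as fp:
        fp.write("\n".join(map(str, nodes)))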
Example 6
def train(train_data, test_data=None, sampler_name='Uniform'):
    
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G, 
            id_map,
            placeholders, batch_size=FLAGS.batch_size,
            max_degree=FLAGS.max_degree, 
            num_neg_samples=FLAGS.neg_sample_size,
            context_pairs = context_pairs)
    
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
    adj_shape = adj_info.get_shape().as_list()

    if FLAGS.model == 'mean_concat':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     concat=True,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    
    elif FLAGS.model == 'mean_add':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     concat=False,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'gcn':

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

           
        # Create model
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="gcn",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     concat=False,
                                     logging=True)

    elif FLAGS.model == 'graphsage_seq':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)
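        # In the reference GraphSAGE code the "seq" aggregator is LSTM-based,
        # treating the sampled neighbors as a sequence.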

    elif FLAGS.model == 'graphsage_maxpool':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)
    
            
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
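        # The "maxpool" aggregator passes neighbor features through a small MLP and
        # reduces them with an element-wise max (GraphSAGE pooling aggregator).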
    elif FLAGS.model == 'graphsage_meanpool':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)
            
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(placeholders, features.shape[0],
                              minibatch.deg,
                              # 2x because graphsage uses concat
                              nodevec_dim=2*FLAGS.dim_1,
                              lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    
    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(sampler_name), sess.graph)
    
    
    # Save model
    saver = tf.train.Saver()
    model_path =  './model/unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    
    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
    
    
    # Restore params of ML sampler model
    if sampler_name == 'ML' or sampler_name == 'FastML':
        sampler_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="MLsampler")
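        # The ML/FastML samplers wrap a pretrained neighbor-selection network; its
        # variables live under the "MLsampler" scope and are restored below from a
        # separately trained checkpoint.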
        #pdb.set_trace() 
        saver_sampler = tf.train.Saver(var_list=sampler_vars)
        sampler_model_path = './model/MLsampler-unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix
        sampler_model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)

        saver_sampler.restore(sess, sampler_model_path + 'model.ckpt')

   
    # Train model
    
    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
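    # Assign ops used to switch adj_info between the training adjacency
    # (minibatch.adj) and the test adjacency (minibatch.test_adj) around validation.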

    val_cost_ = []
    val_mrr_ = []
    shadow_mrr_ = []
    duration_ = []
    
    ln_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.float32)
    lnc_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.int32)
    
    ln_acc = ln_acc.tolil()
    lnc_acc = lnc_acc.tolil()
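    # ln_acc / lnc_acc accumulate the per-node loss values and occurrence counts
    # (model.loss_node / model.loss_node_count) over training. LIL format keeps the
    # incremental writes cheap; both are converted to CSR when saved at the end.
    # (adj_shape is assumed to be set earlier from the adjacency matrix shape.)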

    for epoch in range(FLAGS.epochs): 
        minibatch.shuffle() 

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()

            # Skip the final partial batch; dict.values() is not indexable in
            # Python 3, so check the batch_size placeholder explicitly.
            if feed_dict[placeholders['batch_size']] != FLAGS.batch_size:
                break

            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.ranks, model.aff_all, 
                    model.mrr, model.outputs1, model.loss_node, model.loss_node_count], feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr
            else:
                train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)
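            # train_shadow_mrr is an exponential moving average of the training MRR
            # with decay 0.99; shadow_mrr below tracks the validation MRR the same way.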

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration  = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)
            
            val_cost_.append(val_cost)
            val_mrr_.append(val_mrr)
            shadow_mrr_.append(shadow_mrr)
            duration_.append(duration)


            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)
    
            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print("Iter:", '%04d' % iter, 
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_mrr=", "{:.5f}".format(train_mrr), 
                      "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_mrr=", "{:.5f}".format(val_mrr), 
                      "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
                      "time=", "{:.5f}".format(avg_time))

            
            ln = outs[7].values
            ln_idx = outs[7].indices
            ln_acc[ln_idx[:,0], ln_idx[:,1]] += ln

            lnc = outs[8].values
            lnc_idx = outs[8].indices
            lnc_acc[lnc_idx[:,0], lnc_idx[:,1]] += lnc
                
                
            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break


    print("Validation per epoch in training")
    for ep in range(FLAGS.epochs):
        print("Epoch: %04d"%ep, " val_cost={:.5f}".format(val_cost_[ep]), " val_mrr={:.5f}".format(val_mrr_[ep]), " val_mrr_ema={:.5f}".format(shadow_mrr_[ep]), " duration={:.5f}".format(duration_[ep]))
 
    print("Optimization Finished!")
    
    # Save model
    save_path = saver.save(sess, model_path+'model.ckpt')
    print ('model is saved at %s'%save_path)


    # Save loss node and count
    loss_node_path = './loss_node/unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    loss_node_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)
    if not os.path.exists(loss_node_path):
        os.makedirs(loss_node_path)

    sparse.save_npz(loss_node_path + 'loss_node.npz', sparse.csr_matrix(ln_acc))
    sparse.save_npz(loss_node_path + 'loss_node_count.npz', sparse.csr_matrix(lnc_acc))
    print('loss and count per node are saved at %s' % loss_node_path)
    
    
    
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(sampler_name))

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if not G.node[n]['val'] and not G.node[n]['test']],
                    dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if G.node[n]['val'] or G.node[n]['test']], 
                    dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes, tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(tf.scatter_nd(train_ids, no_update_nodes, tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
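            # Recombine the embedding table: val/test rows stay trainable while train
            # rows are wrapped in tf.stop_gradient, so only the new nodes are updated
            # during this retraining pass.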
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [n for n in G.nodes_iter() if G.node[n]["val"] or G.node[n]["test"]]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(G, 
                id_map,
                placeholders, batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree, 
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs = pairs,
                n2v_retrain=True,
                fixed_n2v=True)
            
            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([model.opt_op, model.loss, model.ranks, model.aff_all, 
                        model.mrr, model.outputs1], feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, 
                              "train_loss=", "{:.5f}".format(outs[1]),
                              "train_mrr=", "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(sampler_name), mod="-test")
            print("Total time: ", train_time+walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Esempio n. 7
0
def train(train_data,
          log_dir,
          Theta=None,
          fixed_neigh_weights=None,
          test_data=None,
          neg_sample_weights=None):
    G = train_data[0]
    Gneg = Graph_complement(G)
    features = train_data[1]
    id_map = train_data[2]
    Adj_mat = Edges_to_Adjacency_mat(G.edges(), len(G.nodes()))
    #     print('A in unsup: ', Adj_mat)
    #     negAdj_mat = Edges_to_Adjacency_mat(Gneg.edges(), len(Gneg.nodes()))
    FLAGS.batch_size = batch_size_def(len(G.edges()))
    FLAGS.negbatch_size = batch_size_def(len(Gneg.edges()))
    FLAGS.samples_1 = min(25, len(G.nodes()))
    FLAGS.samples_2 = min(10, len(G.nodes()))
    FLAGS.negsamples_1 = min(25, len(Gneg.nodes()))
    FLAGS.negsamples_2 = min(10, len(Gneg.nodes()))
    FLAGS.neg_sample_size = FLAGS.batch_size
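    # Batch and neighborhood sample sizes are capped by the sizes of G and its
    # complement Gneg; the negative sample size is tied to the batch size.
    # (batch_size_def and small_big_threshold are assumed to be defined elsewhere
    # in this module.)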

    if (len(G.nodes()) <= small_big_threshold):
        FLAGS.model_size = 'small'
    else:
        FLAGS.model_size = 'big'

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    #print('placeholders: ', placeholders)
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    aggbatch_size = len(minibatch.agg_batch_Z1)
    negaggbatch_size = len(minibatch.agg_batch_Z3)
    negminibatch = EdgeMinibatchIterator(Gneg,
                                         id_map,
                                         placeholders,
                                         batch_size=FLAGS.negbatch_size,
                                         max_degree=FLAGS.max_degree,
                                         num_neg_samples=FLAGS.neg_sample_size,
                                         context_pairs=context_pairs)
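    # A second minibatch iterator runs over the complement graph Gneg, supplying
    # explicit non-edges for the negative-sample placeholders during training.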

    #adj_info = tf.Variable(placeholders['adj_info_ph'], trainable=False, name="adj_info")
    adj_info_ph = tf.placeholder(tf.int32,
                                 shape=minibatch.adj.shape,
                                 name="adj_info_ph")
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    #summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    #minibatch.adj = minibatch.adj.astype(np.int32)
    #print('minibatch.adj.shape: %s, dtype: %s' % (minibatch.adj.shape, np.ndarray.dtype(minibatch.adj)))
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    print_flag = False
    if (Theta is None):
        print_flag = True

    for epoch in range(FLAGS.epochs):

        minibatch.shuffle()
        negminibatch.shuffle()
        iter = 0
        #         print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        if (FLAGS.model_size == 'big'):
            whichbatch = minibatch
        elif (len(G.edges()) > len(Gneg.edges())):
            whichbatch = minibatch
            opwhichbatch = negminibatch
        else:
            whichbatch = negminibatch
            opwhichbatch = minibatch
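        # Drive the epoch with whichever of the positive/negative edge sets is
        # larger, re-shuffling the smaller iterator whenever it runs out
        # (only relevant for the 'small' model size).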

        while (not whichbatch.end()):
            if (FLAGS.model_size == 'small' and opwhichbatch.end()):
                opwhichbatch.shuffle()
            # Construct feed dictionary

            feed_dict = minibatch.next_minibatch_feed_dict()
            negfeed_dict = negminibatch.next_minibatch_feed_dict()
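            # Each positive-edge batch is paired with a batch of non-edges drawn from
            # the complement graph; the negbatch placeholders are filled from
            # negfeed_dict below.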

            if (True):
                feed_dict.update({
                    placeholders['negbatch1']:
                    negfeed_dict[placeholders['batch1']]
                })
                feed_dict.update({
                    placeholders['negbatch2']:
                    negfeed_dict[placeholders['batch2']]
                })
                feed_dict.update({
                    placeholders['negbatch_size']:
                    negfeed_dict[placeholders['batch_size']]
                })
            else:
                batch1 = feed_dict[placeholders['batch1']]
                feed_dict.update({placeholders['negbatch1']: batch1})
                feed_dict.update({
                    placeholders['negbatch2']:
                    negminibatch.feed_dict_negbatch(batch1)
                })
                feed_dict.update({
                    placeholders['negbatch_size']:
                    negfeed_dict[placeholders['batch_size']]
                })
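            # When a fixed Theta is supplied, SGD is skipped entirely and only the
            # closing summary below is computed with the provided parameters.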

            if (Theta is not None):
                break
            if (total_steps == 0):
                print_summary(model, feed_dict, sess, Adj_mat, 'first',
                              print_flag)

            t = time.time()
            # Training step
            outs = sess.run(
                [
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1, model.outputs2,
                    model.negoutputs1, model.negoutputs2
                ],
                feed_dict=feed_dict
            )  #, model.current_similarity   , model.learned_vars['theta_1']  # , model.aggregation
            train_cost = outs[2]
            train_mrr = outs[5]

            #             Z = outs[8]
            #             outs = np.concatenate((outs[6],outs[7]), axis=0)
            #             indices = np.concatenate((feed_dict[placeholders['batch1']],feed_dict[placeholders['batch2']]))
            #             Z = np.zeros((len(G.nodes()),FLAGS.dim_2*2))
            #             for node in G.nodes():
            #                 Z[node,:] = outs[int(np.argwhere(indices==node)[0]),:]#[0:len(G.nodes())] # 8
            #             print('Z shape: ', Z.shape)

            #             if train_shadow_mrr is None:
            #                 train_shadow_mrr = train_mrr #
            #             else:
            #                 train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)
            #
            #             if iter % FLAGS.validate_iter == 0:
            #                 # Validation
            #                 sess.run(val_adj_info.op)
            #                 val_cost, ranks, val_mrr, duration = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size, negminibatch_iter=negminibatch, placeholders=placeholders)
            #                 sess.run(train_adj_info.op)
            #                 epoch_val_costs[-1] += val_cost
            #
            #             if shadow_mrr is None:
            #                 shadow_mrr = val_mrr
            #             else:
            #                 shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)

            #             if total_steps % FLAGS.print_every == 0:
            #                 summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print('loss: ', outs[2])
#                 print("Iter:", '%04d' % iter,
#                       "train_loss=", "{:.5f}".format(train_cost),
#                       "train_mrr=", "{:.5f}".format(train_mrr),
#                       "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
#                       "val_loss=", "{:.5f}".format(val_cost),
#                       "val_mrr=", "{:.5f}".format(val_mrr),
#                       "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
#                       "time=", "{:.5f}".format(avg_time))

#             similarity_weights = outs[7]
#             [new_adj_info, new_batch_edges] = sess.run([model.adj_info, \
#                                                                 model.new_batch_edges], \
#                                                                   feed_dict=feed_dict)
#             minibatch.graph_update(new_adj_info, new_batch_edges)

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if (Theta is not None):
            break
        if total_steps > FLAGS.max_total_steps:
            break


#     print("SGD Optimization Finished!")

#     feed_dict = dict()
#     feed_dict.update({placeholders['batch_size'] : len(G.nodes())})

#     minibatch = EdgeMinibatchIterator(G,
#             id_map,
#             placeholders, batch_size = len(G.edges()),
#             max_degree = FLAGS.max_degree,
#             num_neg_samples = FLAGS.neg_sample_size,
#             context_pairs = context_pairs)
#     feed_dict = minibatch.next_minibatch_feed_dict()
#     _, Z = sess.run([merged, model.aggregation], feed_dict=feed_dict) #, model.concat   ,     aggregator_cls.vars

#     Z_tilde = np.repeat(Z, [len(G.nodes())], axis=0)
#     Z_tilde_tilde = np.tile(Z, (len(G.nodes()),1))
#     final_theta_1 = Quadratic_SDP_solver (Z_tilde, Z_tilde_tilde, FLAGS.dim_1*2, FLAGS.dim_2*2)
# #     Z_centralized = Z - np.mean(Z, axis=0)
# #     final_adj_matrix = np.abs(np.matmul(Z_centralized, np.matmul(final_theta_1, np.transpose(Z_centralized))))
#     final_adj_matrix = np.matmul(Z, np.matmul(final_theta_1, np.transpose(Z)))
#     feed_dict = minibatch.next_minibatch_feed_dict()
#     feed_dict.update({placeholders['dropout']: FLAGS.dropout})
#     feed_dict.update({placeholders['all_nodes']: all_nodes})

#     FLAGS.batch_size = len(G.edges())
#     FLAGS.negbatch_size = len(Gneg.edges())
#     FLAGS.samples_1 = len(G.nodes())
#     FLAGS.samples_2 = len(G.nodes())
#     FLAGS.negsamples_1 = len(Gneg.nodes())
#     FLAGS.negsamples_2 = len(Gneg.nodes())
#     feed_dict = minibatch.batch_feed_dict(G.edges())
#     negfeed_dict = negminibatch.batch_feed_dict(Gneg.edges())
#     feed_dict.update({placeholders['negbatch1']: negfeed_dict[placeholders['batch1']]})
#     feed_dict.update({placeholders['negbatch2']: negfeed_dict[placeholders['batch2']]})
#     feed_dict.update({placeholders['negbatch_size']: negfeed_dict[placeholders['batch_size']]})
#     if(outs is not None):
#         print('loss: ', outs[2])

    final_adj_matrix, final_theta_1, Z, loss, U = print_summary(
        model, feed_dict, sess, Adj_mat, 'last', print_flag)
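    # print_summary('last') yields the quantities returned by this function: the
    # estimated adjacency matrix, theta_1, the node embeddings Z, the loss, and U.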

    #print('Z shape: ', Z.shape)

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
    #del adj_info_ph, adj_info,placeholders
    return final_adj_matrix, G, final_theta_1, Z, loss, U  #, learned_vars
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    prob_matrix = train_data[5]

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      prob_matrix,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
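    # This variant of EdgeMinibatchIterator takes an extra positional prob_matrix
    # argument (presumably neighbor sampling probabilities); the stock GraphSAGE
    # iterator does not.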
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    min_train_loss = 0.0
    fp = open(log_dir() + 'epoch_train_loss.txt', 'w')
    loss_list = []
    epoch_list = []

    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        epoch_train_loss = 0.0

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict(layer_infos)
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]

            epoch_train_loss += train_cost

            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  #
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

        epoch_train_loss = epoch_train_loss / iter

        print('epoch: ' + str(epoch) + '\t' + 'train_loss: ' +
              str(epoch_train_loss) + '\n')
        fp.write('epoch: ' + str(epoch) + '\t' + 'train_loss: ' +
                 str(epoch_train_loss) + '\n')
        loss_list.append(epoch_train_loss)
        epoch_list.append(epoch)

        if epoch == 0:
            min_train_loss = epoch_train_loss
            sess.run(val_adj_info.op)
            node_embeddings, node_list = save_val_embeddings(
                sess, model, minibatch, FLAGS.validate_batch_size, log_dir(),
                layer_infos)
        elif epoch_train_loss < min_train_loss:
            min_train_loss = epoch_train_loss
            sess.run(val_adj_info.op)
            node_embeddings, node_list = save_val_embeddings(
                sess, model, minibatch, FLAGS.validate_batch_size, log_dir(),
                layer_infos)
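        # Validation embeddings are re-exported whenever the epoch's mean training
        # loss improves on the best value seen so far.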

    print("Optimization Finished!")

    fp.close()

    np.save(log_dir() + "val.npy", node_embeddings)
    with open(log_dir() + "val.txt", "w") as fp:
        fp.write("\n".join(map(str, node_list)))

    plt.plot(epoch_list, loss_list)
    plt.savefig(log_dir() + 'loss_trend.png')
    plt.clf()

    ## t-SNE plot
    ## creating t-SNE embedding feature
    #tsne_feat=TSNE(n_components=2).fit_transform(node_embeddings)
    tsne_feat = node_embeddings
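    # The t-SNE projection is commented out, so the raw embeddings are used
    # directly; the scatter plot below takes the first two embedding dimensions as
    # coordinates (this assumes the embedding dimension is at least 2).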

    ## plotting t-SNE embedding
    tsne_df = pd.DataFrame(columns=['tsne0', 'tsne1'])
    tsne_df['tsne0'] = tsne_feat[:, 0]
    tsne_df['tsne1'] = tsne_feat[:, 1]
    plt.figure(figsize=(16, 10))
    sns.scatterplot(x="tsne0",
                    y="tsne1",
                    hue=node_list,
                    palette=sns.color_palette("hls", len(node_list)),
                    data=tsne_df,
                    legend=False,
                    alpha=0.3)
    for i, txt in enumerate(node_list):
        plt.annotate(txt, (tsne_df['tsne0'][i], tsne_df['tsne1'][i]))
    plt.savefig(log_dir() + 't-SNE_plot_node_embeddings.png')
    plt.show()

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict(
                        layer_infos)
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)