Example #1
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])
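        # (The extra all-zero row presumably acts as the feature vector for
        # the padding/"dummy" node index used by the minibatch iterator.)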

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
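    # The adjacency table is fed in once through a placeholder and stored in
    # a non-trainable Variable so it can later be switched between the
    # training and validation adjacency with tf.assign (see train_adj_info /
    # val_adj_info further down).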
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
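        # The GCN aggregator averages self and neighbor vectors rather than
        # concatenating them (concat=False below), so the layer dimensions
        # are doubled to keep the output size comparable to the concat-based
        # variants.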
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # default to saving all variables : added by wy
    saver = tf.train.Saver()

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Ops that swap the adjacency Variable between the training and the test
    # adjacency; defined before branching so the embedding-export code below
    # can use them even when FLAGS.isTrain is False.
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    # Train model
    if FLAGS.isTrain:
        print("== Training Model ==")
        train_shadow_mrr = None
        shadow_mrr = None

        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []
        for epoch in range(FLAGS.epochs):
            minibatch.shuffle()

            iter = 0
            print('Epoch: %04d' % (epoch + 1))
            epoch_val_costs.append(0)
            while not minibatch.end():
                # Construct feed dictionary
                feed_dict = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})

                t = time.time()
                # Training step
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1
                ],
                                feed_dict=feed_dict)
                train_cost = outs[2]
                train_mrr = outs[5]
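                # Exponential moving average of the training MRR with decay
                # 0.99: ema <- ema - (1 - 0.99) * (ema - mrr), i.e.
                # ema = 0.99 * ema + 0.01 * mrr.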
                if train_shadow_mrr is None:
                    train_shadow_mrr = train_mrr  #
                else:
                    train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr -
                                                      train_mrr)

                if iter % FLAGS.validate_iter == 0:
                    # Validation
                    sess.run(val_adj_info.op)
                    val_cost, ranks, val_mrr, duration = evaluate(
                        sess, model, minibatch, size=FLAGS.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost
                if shadow_mrr is None:
                    shadow_mrr = val_mrr
                else:
                    shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

                if total_steps % FLAGS.print_every == 0:
                    summary_writer.add_summary(outs[0], total_steps)

                # Print results
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % FLAGS.print_every == 0:
                    print(
                        "Iter:",
                        '%04d' % iter,
                        "train_loss=",
                        "{:.5f}".format(train_cost),
                        "train_mrr=",
                        "{:.5f}".format(train_mrr),
                        "train_mrr_ema=",
                        "{:.5f}".format(
                            train_shadow_mrr),  # exponential moving average
                        "val_loss=",
                        "{:.5f}".format(val_cost),
                        "val_mrr=",
                        "{:.5f}".format(val_mrr),
                        "val_mrr_ema=",
                        "{:.5f}".format(
                            shadow_mrr),  # exponential moving average
                        "time=",
                        "{:.5f}".format(avg_time))

                iter += 1
                total_steps += 1

                if total_steps > FLAGS.max_total_steps:
                    break

            if total_steps > FLAGS.max_total_steps:
                break
        print("Optimization Finished!")
        # save model
        saver.save(sess, log_dir() + 'model.ckpt')
        print("save model succeed!")
    else:
        print("==Testing Model==")
        ckpt = tf.train.get_checkpoint_state(log_dir())
        print(log_dir())
        print(ckpt)
        print(ckpt.model_ckeckpoint_path)
        raw_input()
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_ckeckpoint_path)
        else:
            pass

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
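            # Rebuild the full embedding matrix from the two row sets:
            # gradients flow only through the val/test rows, while the
            # already-trained rows are frozen with tf.stop_gradient below.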
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Example #2
def train(train_data, test_data, bs=None):
    G = train_data
    G_test = test_data

    if G.features is not None:
        # pad with dummy zero vector
        G.features = np.vstack([G.features, np.zeros((G.features.shape[1], ))])
    features = G.features
    # context_pairs = train_data[3] if FLAGS.random_context else None

    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(
        G,
        G_test,
        bs,
        # id_map,
        placeholders,
        batch_size=FLAGS.batch_size,
        max_degree=FLAGS.max_degree,
        num_neg_samples=FLAGS.neg_sample_size,
        # context_pairs = context_pairs
    )
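    # NOTE: this variant of EdgeMinibatchIterator apparently takes the test
    # graph and a `bs` argument in place of the id_map / context_pairs
    # arguments that are commented out above.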
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        # minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  #
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss=",
                    "{:.5f}".format(val_cost),
                    "val_mrr=",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema=",
                    "{:.5f}".format(shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
Example #3
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.compat.v1.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.compat.v1.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize WandB experiment
    wandb.init(project='GraphSAGE_trial',
               entity='cvl-liu-01',
               save_code=True,
               tags=['unsupervised'])
    wandb.config.update(flags.FLAGS)

    # Initialize session
    sess = tf.compat.v1.Session(config=config)
    merged = tf.compat.v1.summary.merge_all()
    summary_writer = tf.compat.v1.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.compat.v1.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Init saver
    saver = tf.compat.v1.train.Saver(max_to_keep=8,
                                     keep_checkpoint_every_n_hours=1)
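    # (max_to_keep=8 retains only the 8 most recent checkpoints; in addition,
    # one checkpoint per hour of training is kept permanently.)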

    # Train model
    train_shadow_mrr = None
    val_shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
    val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  #
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            # Validation
            if iter % FLAGS.validate_iter == 0:
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if val_shadow_mrr is None:
                val_shadow_mrr = val_mrr
            else:
                val_shadow_mrr -= (1 - 0.99) * (val_shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "[%03d/%03d]" % (epoch + 1, FLAGS.epochs),
                    "Iter:",
                    '[%05d/%05d]' % (iter, minibatch.num_training_batches()),
                    "train_loss =",
                    "{:.5f}".format(train_cost),
                    "train_mrr =",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema =",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss =",
                    "{:.5f}".format(val_cost),
                    "val_mrr =",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema =",
                    "{:.5f}".format(
                        val_shadow_mrr),  # exponential moving average
                    "time =",
                    "{:.5f}".format(avg_time))

            # W&B Logging
            if FLAGS.wandb_log and iter % FLAGS.wandb_log_iter == 0:
                wandb.log({'train_loss': train_cost, 'epoch': epoch})
                wandb.log({'train_mrr': train_mrr, 'epoch': epoch})
                wandb.log({'train_mrr_ema': train_shadow_mrr, 'epoch': epoch})
                wandb.log({'val_loss': val_cost, 'epoch': epoch})
                wandb.log({'val_mrr': val_mrr, 'epoch': epoch})
                wandb.log({'val_mrr_ema': val_shadow_mrr, 'epoch': epoch})
                wandb.log({'time': avg_time, 'epoch': epoch})

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                print('Max total steps reached!')
                break

        # Save embeddings
        if FLAGS.save_embeddings and epoch % FLAGS.save_embeddings_epoch == 0:
            save_val_embeddings(sess, model, minibatch,
                                FLAGS.validate_batch_size, log_dir())

            # Also report classifier metric on the embedding
            all_tr_res, all_ts_res = osm_classif_eval.evaluate(
                FLAGS.train_prefix, log_dir(), n_iter=FLAGS.classif_n_iter)
            if FLAGS.wandb_log:
                wandb.log(all_tr_res)
                wandb.log(all_ts_res)

        # Save Model checkpoints
        if FLAGS.save_checkpoints and epoch % FLAGS.save_checkpoints_epoch == 0:
            # saver.save(sess, log_dir() + 'model', global_step=1000)
            print('Save model checkpoint:', wandb.run.dir, iter, total_steps,
                  epoch)
            saver.save(
                sess,
                os.path.join(wandb.run.dir,
                             "model-" + str(epoch + 1) + ".ckpt"))

        if total_steps > FLAGS.max_total_steps:
            print('Max total steps reached!')
            break

    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Example #4
def train(train_data, test_data=None, sampler_name='Uniform'):
    
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G, 
            id_map,
            placeholders, batch_size=FLAGS.batch_size,
            max_degree=FLAGS.max_degree, 
            num_neg_samples=FLAGS.neg_sample_size,
            context_pairs = context_pairs)
    
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
    adj_shape = adj_info.get_shape().as_list()

    if FLAGS.model == 'mean_concat':
        # Create model
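        # Choose the neighbor sampler: 'Uniform' samples neighbors uniformly
        # at random, while 'ML' / 'FastML' use a learned sampler whose
        # parameters are restored from the MLsampler checkpoint further below.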

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     concat=True,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    
    elif FLAGS.model == 'mean_add':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     concat=False,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'gcn':

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

           
        # Create model
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="gcn",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     concat=False,
                                     logging=True)

    elif FLAGS.model == 'graphsage_seq':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     identity_dim = FLAGS.identity_dim,
                                     aggregator_type="seq",
                                     model_size=FLAGS.model_size,
                                     logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)
    
            
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="maxpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)
            
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="meanpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(placeholders, features.shape[0],
                                       minibatch.deg,
                                       #2x because graphsage uses concat
                                       nodevec_dim=2*FLAGS.dim_1,
                                       lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    
    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(sampler_name), sess.graph)
    
    
    # Save model
    saver = tf.train.Saver()
    model_path =  './model/unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    
    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
    
    
    # Restore params of ML sampler model
    if sampler_name == 'ML' or sampler_name == 'FastML':
        sampler_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="MLsampler")
        #pdb.set_trace() 
        saver_sampler = tf.train.Saver(var_list=sampler_vars)
        sampler_model_path = './model/MLsampler-unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix
        sampler_model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)

        saver_sampler.restore(sess, sampler_model_path + 'model.ckpt')

   
    # Train model
    
    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    val_cost_ = []
    val_mrr_ = []
    shadow_mrr_ = []
    duration_ = []
    
    ln_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.float32)
    lnc_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.int32)
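    # ln_acc accumulates the per-node-pair loss values (model.loss_node) and
    # lnc_acc the corresponding update counts (model.loss_node_count); both
    # are converted to LIL format so the incremental writes in the training
    # loop stay cheap.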
    
    ln_acc = ln_acc.tolil()
    lnc_acc = lnc_acc.tolil()

    for epoch in range(FLAGS.epochs): 
        minibatch.shuffle() 

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()

            # Skip the last minibatch of the epoch if it is smaller than
            # batch_size (the first feed_dict entry is assumed to hold the
            # batch size).
            if list(feed_dict.values())[0] != FLAGS.batch_size:
                break

            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.ranks, model.aff_all, 
                    model.mrr, model.outputs1, model.loss_node, model.loss_node_count], feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr
            else:
                train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration  = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)
            
            val_cost_.append(val_cost)
            val_mrr_.append(val_mrr)
            shadow_mrr_.append(shadow_mrr)
            duration_.append(duration)


            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)
    
            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print("Iter:", '%04d' % iter, 
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_mrr=", "{:.5f}".format(train_mrr), 
                      "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_mrr=", "{:.5f}".format(val_mrr), 
                      "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
                      "time=", "{:.5f}".format(avg_time))

            
            ln = outs[7].values
            ln_idx = outs[7].indices
            ln_acc[ln_idx[:,0], ln_idx[:,1]] += ln

            lnc = outs[8].values
            lnc_idx = outs[8].indices
            lnc_acc[lnc_idx[:,0], lnc_idx[:,1]] += lnc
                
                
            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break


    print("Validation per epoch in training")
    for ep in range(FLAGS.epochs):
        print("Epoch: %04d"%ep, " val_cost={:.5f}".format(val_cost_[ep]), " val_mrr={:.5f}".format(val_mrr_[ep]), " val_mrr_ema={:.5f}".format(shadow_mrr_[ep]), " duration={:.5f}".format(duration_[ep]))
 
    print("Optimization Finished!")
    
    # Save model
    save_path = saver.save(sess, model_path+'model.ckpt')
    print('model is saved at %s' % save_path)


    # Save loss node and count
    loss_node_path = './loss_node/unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    loss_node_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)
    if not os.path.exists(loss_node_path):
        os.makedirs(loss_node_path)

    sparse.save_npz(loss_node_path + 'loss_node.npz', sparse.csr_matrix(ln_acc))
    sparse.save_npz(loss_node_path + 'loss_node_count.npz',
                    sparse.csr_matrix(lnc_acc))
    print('loss and count per node are saved at %s' % loss_node_path)
    
    
    
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(sampler_name))

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if not G.node[n]['val'] and not G.node[n]['test']],
                    dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if G.node[n]['val'] or G.node[n]['test']], 
                    dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes, tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(tf.scatter_nd(train_ids, no_update_nodes, tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [n for n in G.nodes_iter() if G.node[n]["val"] or G.node[n]["test"]]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(G, 
                id_map,
                placeholders, batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree, 
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs = pairs,
                n2v_retrain=True,
                fixed_n2v=True)
            
            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([model.opt_op, model.loss, model.ranks, model.aff_all, 
                        model.mrr, model.outputs1], feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, 
                              "train_loss=", "{:.5f}".format(outs[1]),
                              "train_mrr=", "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(sampler_name), mod="-test")
            print("Total time: ", train_time+walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Example #5
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    # sess.run(feed_dict={adj_info_ph: minibatch.adj})

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    saver = tf.train.Saver()
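    # Restore every variable (including adj_info) from a previously trained
    # supervised checkpoint instead of running
    # tf.global_variables_initializer().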
    saver.restore(sess, "supervisedTrainedModel_MC/model.ckpt")

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)
        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())
Example #6
def train(train_data,
          log_dir,
          Theta=None,
          fixed_neigh_weights=None,
          test_data=None,
          neg_sample_weights=None):
    G = train_data[0]
    Gneg = Graph_complement(G)
    features = train_data[1]
    id_map = train_data[2]
    Adj_mat = Edges_to_Adjacency_mat(G.edges(), len(G.nodes()))
    #     print('A in unsup: ', Adj_mat)
    #     negAdj_mat = Edges_to_Adjacency_mat(Gneg.edges(), len(Gneg.nodes()))
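    # Batch sizes and per-layer neighbor sample counts are derived from the
    # graph itself: positive batches come from G's edges, negative batches
    # from the edges of its complement Gneg, and the two layers' sample sizes
    # are capped at 25 and 10 nodes respectively.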
    FLAGS.batch_size = batch_size_def(len(G.edges()))
    FLAGS.negbatch_size = batch_size_def(len(Gneg.edges()))
    FLAGS.samples_1 = min(25, len(G.nodes()))
    FLAGS.samples_2 = min(10, len(G.nodes()))
    FLAGS.negsamples_1 = min(25, len(Gneg.nodes()))
    FLAGS.negsamples_2 = min(10, len(Gneg.nodes()))
    FLAGS.neg_sample_size = FLAGS.batch_size

    if (len(G.nodes()) <= small_big_threshold):
        FLAGS.model_size = 'small'
    else:
        FLAGS.model_size = 'big'

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    #print('placeholders: ', placeholders)
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    aggbatch_size = len(minibatch.agg_batch_Z1)
    negaggbatch_size = len(minibatch.agg_batch_Z3)
    negminibatch = EdgeMinibatchIterator(Gneg,
                                         id_map,
                                         placeholders,
                                         batch_size=FLAGS.negbatch_size,
                                         max_degree=FLAGS.max_degree,
                                         num_neg_samples=FLAGS.neg_sample_size,
                                         context_pairs=context_pairs)

    #adj_info = tf.Variable(placeholders['adj_info_ph'], trainable=False, name="adj_info")
    adj_info_ph = tf.placeholder(tf.int32,
                                 shape=minibatch.adj.shape,
                                 name="adj_info_ph")
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
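        # Note: in this variant each SAGEInfo carries an extra field with the
        # per-layer negative-sample count in addition to the usual sample
        # size and output dimension.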
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    #summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    #minibatch.adj = minibatch.adj.astype(np.int32)
    #print('minibatch.adj.shape: %s, dtype: %s' % (minibatch.adj.shape, np.ndarray.dtype(minibatch.adj)))
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    print_flag = False
    if (Theta is None):
        print_flag = True

    for epoch in range(FLAGS.epochs):

        minibatch.shuffle()
        negminibatch.shuffle()
        iter = 0
        #         print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
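        # For 'small' models the epoch is driven by whichever edge iterator
        # (positive graph or its complement) has more edges, and the other is
        # reshuffled whenever it runs out; for 'big' models only the positive
        # iterator drives the loop.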
        if (FLAGS.model_size == 'big'):
            whichbatch = minibatch
        elif (len(G.edges()) > len(Gneg.edges())):
            whichbatch = minibatch
            opwhichbatch = negminibatch
        else:
            whichbatch = negminibatch
            opwhichbatch = minibatch

        while (not whichbatch.end()):
            if (FLAGS.model_size == 'small' and opwhichbatch.end()):
                opwhichbatch.shuffle()
            # Construct feed dictionary

            feed_dict = minibatch.next_minibatch_feed_dict()
            negfeed_dict = negminibatch.next_minibatch_feed_dict()
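            # The else branch below is unreachable: negative batches are
            # always taken from the complement-graph iterator.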

            if (True):
                feed_dict.update({
                    placeholders['negbatch1']:
                    negfeed_dict[placeholders['batch1']]
                })
                feed_dict.update({
                    placeholders['negbatch2']:
                    negfeed_dict[placeholders['batch2']]
                })
                feed_dict.update({
                    placeholders['negbatch_size']:
                    negfeed_dict[placeholders['batch_size']]
                })
            else:
                batch1 = feed_dict[placeholders['batch1']]
                feed_dict.update({placeholders['negbatch1']: batch1})
                feed_dict.update({
                    placeholders['negbatch2']:
                    negminibatch.feed_dict_negbatch(batch1)
                })
                feed_dict.update({
                    placeholders['negbatch_size']:
                    negfeed_dict[placeholders['batch_size']]
                })
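            # When a fixed Theta is supplied no training step is run (see the
            # break below); the feed_dict built above is only reused for the
            # final print_summary call.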

            if (Theta is not None):
                break
            if (total_steps == 0):
                print_summary(model, feed_dict, sess, Adj_mat, 'first',
                              print_flag)

            t = time.time()
            # Training step
            outs = sess.run(
                [
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1, model.outputs2,
                    model.negoutputs1, model.negoutputs2
                ],
                feed_dict=feed_dict
            )  #, model.current_similarity   , model.learned_vars['theta_1']  # , model.aggregation
            train_cost = outs[2]
            train_mrr = outs[5]

            #             Z = outs[8]
            #             outs = np.concatenate((outs[6],outs[7]), axis=0)
            #             indices = np.concatenate((feed_dict[placeholders['batch1']],feed_dict[placeholders['batch2']]))
            #             Z = np.zeros((len(G.nodes()),FLAGS.dim_2*2))
            #             for node in G.nodes():
            #                 Z[node,:] = outs[int(np.argwhere(indices==node)[0]),:]#[0:len(G.nodes())] # 8
            #             print('Z shape: ', Z.shape)

            #             if train_shadow_mrr is None:
            #                 train_shadow_mrr = train_mrr #
            #             else:
            #                 train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)
            #
            #             if iter % FLAGS.validate_iter == 0:
            #                 # Validation
            #                 sess.run(val_adj_info.op)
            #                 val_cost, ranks, val_mrr, duration = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size, negminibatch_iter=negminibatch, placeholders=placeholders)
            #                 sess.run(train_adj_info.op)
            #                 epoch_val_costs[-1] += val_cost
            #
            #             if shadow_mrr is None:
            #                 shadow_mrr = val_mrr
            #             else:
            #                 shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)

            #             if total_steps % FLAGS.print_every == 0:
            #                 summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print('loss: ', outs[2])
#                 print("Iter:", '%04d' % iter,
#                       "train_loss=", "{:.5f}".format(train_cost),
#                       "train_mrr=", "{:.5f}".format(train_mrr),
#                       "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
#                       "val_loss=", "{:.5f}".format(val_cost),
#                       "val_mrr=", "{:.5f}".format(val_mrr),
#                       "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
#                       "time=", "{:.5f}".format(avg_time))

#             similarity_weights = outs[7]
#             [new_adj_info, new_batch_edges] = sess.run([model.adj_info, \
#                                                                 model.new_batch_edges], \
#                                                                   feed_dict=feed_dict)
#             minibatch.graph_update(new_adj_info, new_batch_edges)

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if (Theta is not None):
            break
        if total_steps > FLAGS.max_total_steps:
            break


#     print("SGD Optimization Finished!")

#     feed_dict = dict()
#     feed_dict.update({placeholders['batch_size'] : len(G.nodes())})

#     minibatch = EdgeMinibatchIterator(G,
#             id_map,
#             placeholders, batch_size = len(G.edges()),
#             max_degree = FLAGS.max_degree,
#             num_neg_samples = FLAGS.neg_sample_size,
#             context_pairs = context_pairs)
#     feed_dict = minibatch.next_minibatch_feed_dict()
#     _, Z = sess.run([merged, model.aggregation], feed_dict=feed_dict) #, model.concat   ,     aggregator_cls.vars

#     Z_tilde = np.repeat(Z, [len(G.nodes())], axis=0)
#     Z_tilde_tilde = np.tile(Z, (len(G.nodes()),1))
#     final_theta_1 = Quadratic_SDP_solver (Z_tilde, Z_tilde_tilde, FLAGS.dim_1*2, FLAGS.dim_2*2)
# #     Z_centralized = Z - np.mean(Z, axis=0)
# #     final_adj_matrix = np.abs(np.matmul(Z_centralized, np.matmul(final_theta_1, np.transpose(Z_centralized))))
#     final_adj_matrix = np.matmul(Z, np.matmul(final_theta_1, np.transpose(Z)))
#     feed_dict = minibatch.next_minibatch_feed_dict()
#     feed_dict.update({placeholders['dropout']: FLAGS.dropout})
#     feed_dict.update({placeholders['all_nodes']: all_nodes})

#     FLAGS.batch_size = len(G.edges())
#     FLAGS.negbatch_size = len(Gneg.edges())
#     FLAGS.samples_1 = len(G.nodes())
#     FLAGS.samples_2 = len(G.nodes())
#     FLAGS.negsamples_1 = len(Gneg.nodes())
#     FLAGS.negsamples_2 = len(Gneg.nodes())
#     feed_dict = minibatch.batch_feed_dict(G.edges())
#     negfeed_dict = negminibatch.batch_feed_dict(Gneg.edges())
#     feed_dict.update({placeholders['negbatch1']: negfeed_dict[placeholders['batch1']]})
#     feed_dict.update({placeholders['negbatch2']: negfeed_dict[placeholders['batch2']]})
#     feed_dict.update({placeholders['negbatch_size']: negfeed_dict[placeholders['batch_size']]})
#     if(outs is not None):
#         print('loss: ', outs[2])

    final_adj_matrix, final_theta_1, Z, loss, U = print_summary(
        model, feed_dict, sess, Adj_mat, 'last', print_flag)

    #print('Z shape: ', Z.shape)

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
    #del adj_info_ph, adj_info,placeholders
    return final_adj_matrix, G, final_theta_1, Z, loss, U  #, learned_vars
Example #7
def train(train_data,
          minibatch,
          model_name="graphsage_mean",
          profile=False,
          test_data=None):
    G = train_data[0]
    # features = train_data[1]
    id_map = train_data[1]
    '''
        if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])
    '''

    context_pairs = train_data[2] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch.set_place_holder(placeholders)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if model_name == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)
    elif model_name == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)

    elif model_name == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)

    elif model_name == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)
    elif model_name == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)

    elif model_name == 'n2v':
        model = Node2VecModel(
            placeholders,
            602,
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # profile init
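    # FULL_TRACE run options and RunMetadata make each sess.run emit timing
    # data; TimeLiner (presumably a small helper defined elsewhere in this
    # project) merges the per-step Chrome traces into a single timeline.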
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    many_runs_timeline = TimeLiner()

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            if profile:
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1
                ],
                                feed_dict=feed_dict,
                                options=options,
                                run_metadata=run_metadata)
                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                if total_steps >= 50:
                    many_runs_timeline.update_timeline(chrome_trace)
            else:
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1
                ],
                                feed_dict=feed_dict)
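            # The GraphDef is exported once and the break below exits after
            # the first step, so the training/validation bookkeeping that
            # follows is never reached in this example.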
            tf.train.write_graph(sess.graph_def,
                                 './export_models',
                                 'graphsage_{0}.pb'.format(model_name),
                                 as_text=False)
            break
            train_cost = outs[2]
            train_mrr = outs[5]

            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  #
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            if total_steps >= 50:
                avg_time = (avg_time * (total_steps - 50) + time.time() -
                            t) / (total_steps - 50 + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss=",
                    "{:.5f}".format(val_cost),
                    "val_mrr=",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema=",
                    "{:.5f}".format(shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    if profile:
        # write profile data into json format
        file_name = model_name + ".json"
        many_runs_timeline.save(file_name)
        file_name = model_name + "_profile_time.txt"
        with open(file_name, "w") as f:
            f.write("time= {:.5f}s".format(avg_time))
    else:
        file_name = model_name + "_no_profile_time.txt"
        with open(file_name, "w") as f:
            f.write("time= {:.5f}s".format(avg_time))
    print("Optimization Finished!")
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    prob_matrix = train_data[5]
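    # This variant additionally unpacks a probability matrix and hands it to
    # the minibatch iterator below, presumably to bias edge/neighbor sampling.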

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      prob_matrix,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    min_train_loss = 0.0
    fp = open(log_dir() + 'epoch_train_loss.txt', 'w')
    loss_list = []
    epoch_list = []

    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        epoch_train_loss = 0.0

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict(layer_infos)
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]

            epoch_train_loss += train_cost
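            # The block below keeps an exponential moving average of the
            # training MRR (decay 0.99).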

            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  #
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

        epoch_train_loss = epoch_train_loss / iter

        print('epoch: ' + str(epoch) + '\t' + 'train_loss: ' +
              str(epoch_train_loss) + '\n')
        fp.write('epoch: ' + str(epoch) + '\t' + 'train_loss: ' +
                 str(epoch_train_loss) + '\n')
        loss_list.append(epoch_train_loss)
        epoch_list.append(epoch)

        if epoch == 0:
            min_train_loss = epoch_train_loss
            sess.run(val_adj_info.op)
            node_embeddings, node_list = save_val_embeddings(
                sess, model, minibatch, FLAGS.validate_batch_size, log_dir(),
                layer_infos)
        elif epoch_train_loss < min_train_loss:
            min_train_loss = epoch_train_loss
            sess.run(val_adj_info.op)
            node_embeddings, node_list = save_val_embeddings(
                sess, model, minibatch, FLAGS.validate_batch_size, log_dir(),
                layer_infos)

    print("Optimization Finished!")

    fp.close()

    np.save(log_dir() + "val.npy", node_embeddings)
    with open(log_dir() + "val.txt", "w") as fp:
        fp.write("\n".join(map(str, node_list)))

    plt.plot(epoch_list, loss_list)
    plt.savefig(log_dir() + 'loss_trend.png')
    plt.clf()

    ## t-SNE plot
    ## creating t-SNE embedding feature
    #tsne_feat=TSNE(n_components=2).fit_transform(node_embeddings)
    tsne_feat = node_embeddings
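    # t-SNE is bypassed here; the first two embedding coordinates are plotted
    # directly.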

    ## plotting t-SNE embedding
    tsne_df = pd.DataFrame(columns=['tsne0', 'tsne1'])
    tsne_df['tsne0'] = tsne_feat[:, 0]
    tsne_df['tsne1'] = tsne_feat[:, 1]
    plt.figure(figsize=(16, 10))
    sns.scatterplot(x="tsne0",
                    y="tsne1",
                    hue=node_list,
                    palette=sns.color_palette("hls", len(node_list)),
                    data=tsne_df,
                    legend=False,
                    alpha=0.3)
    for i, txt in enumerate(node_list):
        plt.annotate(txt, (tsne_df['tsne0'][i], tsne_df['tsne1'][i]))
    plt.savefig(log_dir() + 't-SNE_plot_node_embeddings.png')
    plt.show()

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict(
                        layer_infos)
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)