# NOTE: these snippets assume the usual GraphSAGE (TF1.x) setup, e.g.:
#   import time
#   import numpy as np
#   import tensorflow as tf
#   from graphsage.supervised_models import SupervisedGraphsage
#   from graphsage.models import SampleAndAggregate, SAGEInfo, Node2VecModel
#   from graphsage.minibatch import NodeMinibatchIterator, EdgeMinibatchIterator
#   from graphsage.neigh_samplers import UniformNeighborSampler
# plus module-level FLAGS and helpers (evaluate, incremental_evaluate,
# log_dir / results_folder, current_time) defined elsewhere in each file.

# Example 1
def train(train_data):

    G = train_data[0]
    features = train_data[1]
    labels = train_data[2]
    train_nodes = train_data[3]
    test_nodes = train_data[4]
    val_nodes = train_data[5]
    num_classes = 2

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      placeholders,
                                      labels,
                                      train_nodes,
                                      test_nodes,
                                      val_nodes,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree)
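    # Feeding the (large) adjacency table in through a placeholder and a
    # non-trainable Variable keeps it out of the serialized graph definition;
    # the tf.assign ops created later swap it between train/test adjacency.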
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]
    else:
        raise Exception('Error: model name unrecognized.')

    model = SupervisedGraphsage(num_classes,
                                placeholders,
                                features,
                                adj_info,
                                minibatch.deg,
                                layer_infos,
                                model_size=FLAGS.model_size,
                                sigmoid_loss=FLAGS.sigmoid,
                                identity_dim=FLAGS.identity_dim,
                                logging=True)

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(results_folder, sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

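    # Swap ops: point adj_info at the train or test adjacency as needed.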
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels, batch = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, acc, prec, rec, f1_score, conf_mat = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, acc, prec, rec, f1_score, conf_mat, fpr, tpr, thresholds = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                # NOTE: `eval` here is the project's own metrics helper
                # (accuracy/precision/recall/F1/ROC), not the Python builtin.
                train_acc, train_prec, train_rec, train_f1_score, train_conf_mat, fpr, tpr, thresholds = eval(
                    labels, outs[-1])
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_accuracy=",
                      "{:.5f}".format(train_acc), "train_precision=",
                      "{:.5f}".format(train_prec), "train_recall=",
                      "{:.5f}".format(train_rec), "train_f1_score=",
                      "{:.5f}".format(train_f1_score))
                with open(results_folder + "/train_stats.txt", "a") as fp:
                    fp.write(
                        "Iter:{:d} loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d}\n"
                        .format(iter, train_cost, train_acc, train_prec,
                                train_rec, train_f1_score,
                                train_conf_mat[0][0], train_conf_mat[0][1],
                                train_conf_mat[1][0], train_conf_mat[1][1]))

            iter += 1
            total_steps += 1

            if total_steps > int(FLAGS.max_total_steps):
                break

        if total_steps > int(FLAGS.max_total_steps):
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_acc, val_prec, val_rec, val_f1_score, val_conf_mat = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "val_cost=", "{:.5f}".format(val_cost),
          "val_acc=", "{:.5f}".format(val_acc), "val_prec=",
          "{:.5f}".format(val_prec), "val_rec=", "{:.5f}".format(val_rec),
          "val_f1_score=", "{:.5f}".format(val_f1_score), "val_conf_mat=",
          val_conf_mat)
    with open(results_folder + "/val_stats.txt", "a") as fp:
        fp.write(
            "loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d} time=={:s}\n"
            .format(val_cost, val_acc, val_prec, val_rec, val_f1_score,
                    val_conf_mat[0][0], val_conf_mat[0][1], val_conf_mat[1][0],
                    val_conf_mat[1][1], current_time))
    print("Writing test set stats to file")
    test_cost, test_acc, test_prec, test_rec, test_f1_score, test_conf_mat = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    print("Full test stats:", "test_cost=", "{:.5f}".format(test_cost),
          "test_acc=", "{:.5f}".format(test_acc), "test_prec=",
          "{:.5f}".format(test_prec), "test_rec=", "{:.5f}".format(test_rec),
          "test_f1_score=", "{:.5f}".format(test_f1_score), "test_conf_mat=",
          test_conf_mat)
    with open(results_folder + "/test_stats.txt", "a") as fp:
        fp.write(
            "loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d} time=={:s}\n"
            .format(test_cost, test_acc, test_prec, test_rec, test_f1_score,
                    test_conf_mat[0][0], test_conf_mat[0][1],
                    test_conf_mat[1][0], test_conf_mat[1][1], current_time))
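
# A minimal, hypothetical driver for the train() above. The loader name
# `load_data` and FLAGS.train_prefix are assumptions (mirroring the upstream
# GraphSAGE entry point); only the 6-tuple layout is taken from train() itself.
def main(argv=None):
    print("Loading training data..")
    # expected: (G, features, labels, train_nodes, test_nodes, val_nodes)
    train_data = load_data(FLAGS.train_prefix)
    print("Done loading training data..")
    train(train_data)

if __name__ == '__main__':
    tf.app.run()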

# Example 2
def train(train_data, test_data=None):

    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

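    # Random-walk co-occurrence pairs (used as extra context) when enabled.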
    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_f1_mic=",
                      "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                      "{:.5f}".format(train_f1_mac), "val_loss=",
                      "{:.5f}".format(val_cost), "val_f1_mic=",
                      "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                      "{:.5f}".format(val_f1_mac), "time=",
                      "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
          "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write(
            "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            val_cost, val_f1_mic, val_f1_mac))
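
# For reference, minimal versions of two helpers these snippets rely on,
# mirroring the upstream GraphSAGE supervised example; they are sketches of
# the expected interfaces, not necessarily the exact code used here
# (calc_f1 additionally assumes `from sklearn import metrics`):
def construct_placeholders(num_classes):
    # Placeholder keys match what NodeMinibatchIterator and the feed dicts use.
    placeholders = {
        'labels': tf.placeholder(tf.float32, shape=(None, num_classes),
                                 name='labels'),
        'batch': tf.placeholder(tf.int32, shape=(None), name='batch1'),
        'dropout': tf.placeholder_with_default(0., shape=(), name='dropout'),
        'batch_size': tf.placeholder(tf.int32, name='batch_size'),
    }
    return placeholders

def calc_f1(y_true, y_pred):
    # Multi-label (sigmoid) outputs are thresholded; otherwise take argmax.
    if not FLAGS.sigmoid:
        y_true = np.argmax(y_true, axis=1)
        y_pred = np.argmax(y_pred, axis=1)
    else:
        y_pred[y_pred > 0.5] = 1
        y_pred[y_pred <= 0.5] = 0
    return (metrics.f1_score(y_true, y_pred, average="micro"),
            metrics.f1_score(y_true, y_pred, average="macro"))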

# Example 3 (fragment: the start of this example was lost; the id-map loop
# header below is reconstructed from the indented body that survived)
compteur = 0
dicoIdMap = {}
for node in G.nodes:
    dicoIdMap[node] = compteur
    compteur += 1
id_map = dicoIdMap

minibatch = EdgeMinibatchIterator(G,
                                  edgelist,
                                  test_edgelist,
                                  id_map,
                                  batch_size=100,
                                  max_degree=3)

sampler = UniformNeighborSampler(minibatch.adj)

layer_infos = [
    SAGEInfo("node", sampler, 3, 128),
    SAGEInfo("node", sampler, 3, 256)
]

features = []

maximum = -1
"""
Commented-out sketch (left unfinished in the original): find the widest
text embedding across nodes.

for node in G.nodes:
    text_embed = dicoEmbT[str(node)]
    maximum = max(tf.convert_to_tensor(text_embed.detach().numpy()).shape[1], maximum)
print("MAXIMUM", maximum)
"""
# (the remainder of this example was truncated)

# Example 4
    def _create_model(self, num_classes, placeholders, features, adj_info,
                      minibatch):
        if self.model_name == 'graphsage_mean':
            # Create model
            sampler = UniformNeighborSampler(adj_info)
            if self.samples_3 != 0:
                layer_infos = [
                    SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                    SAGEInfo("node", sampler, self.samples_2, self.dim_2),
                    SAGEInfo("node", sampler, self.samples_3, self.dim_2)
                ]
            elif self.samples_2 != 0:
                layer_infos = [
                    SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                    SAGEInfo("node", sampler, self.samples_2, self.dim_2)
                ]
            else:
                layer_infos = [
                    SAGEInfo("node", sampler, self.samples_1, self.dim_1)
                ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        model_size=self.model_size,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        elif self.model_name == 'gcn':
            # Create model
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, 2 * self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, 2 * self.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        aggregator_type="gcn",
                                        model_size=self.model_size,
                                        concat=False,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        elif self.model_name == 'graphsage_seq':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        aggregator_type="seq",
                                        model_size=self.model_size,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        elif self.model_name == 'graphsage_maxpool':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        aggregator_type="maxpool",
                                        model_size=self.model_size,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        elif self.model_name == 'graphsage_meanpool':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        aggregator_type="meanpool",
                                        model_size=self.model_size,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        else:
            raise Exception('Error: model name unrecognized.')
        return model

# Example 5
def train(train_data, test_data=None):

    # adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask

    G = train_data[0]
    features = train_data[1]
    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

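    # "Shadow" values track exponential moving averages of train/val MRR.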
    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  # initialize the EMA
            else:
                # exponential moving average with decay 0.99
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_mrr=",
                      "{:.5f}".format(train_mrr), "train_mrr_ema=",
                      "{:.5f}".format(train_shadow_mrr))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

# Example 6
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  # initialize the EMA
            else:
                # exponential moving average with decay 0.99
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss=",
                    "{:.5f}".format(val_cost),
                    "val_mrr=",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema=",
                    "{:.5f}".format(shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
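            # (rows for already-trained nodes are rebuilt with tf.scatter_nd
            # and wrapped in tf.stop_gradient, so only val/test embeddings
            # receive gradient updates during this extra n2v phase)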
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)

# Example 7
    def _create_model(self, sampler_name, placeholders, features, adj_info,
                      minibatch):
        if self.model_name == 'mean_concat':
            # Create model
            sampler = self._create_sampler(sampler_name, adj_info, features)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]
            model = SampleAndAggregate(placeholders,
                                       features,
                                       adj_info,
                                       minibatch.deg,
                                       concat=True,
                                       layer_infos=layer_infos,
                                       weight_decay=self.weight_decay,
                                       learning_rate=self.learning_rate,
                                       neg_sample_size=self.neg_sample_size,
                                       batch_size=self.batch_size,
                                       model_size=self.model_size,
                                       identity_dim=self.identity_dim,
                                       logging=True)
        elif self.model_name == 'mean_add':
            # Create model
            sampler = self._create_sampler(sampler_name, adj_info, features)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]
            model = SampleAndAggregate(placeholders,
                                       features,
                                       adj_info,
                                       minibatch.deg,
                                       layer_infos=layer_infos,
                                       weight_decay=self.weight_decay,
                                       learning_rate=self.learning_rate,
                                       neg_sample_size=self.neg_sample_size,
                                       batch_size=self.batch_size,
                                       concat=False,
                                       model_size=self.model_size,
                                       identity_dim=self.identity_dim,
                                       logging=True)
        elif self.model_name == 'gcn':
            # Create model
            sampler = self._create_sampler(sampler_name, adj_info, features)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, 2 * self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, 2 * self.dim_2)
            ]
            model = SampleAndAggregate(placeholders,
                                       features,
                                       adj_info,
                                       minibatch.deg,
                                       layer_infos=layer_infos,
                                       weight_decay=self.weight_decay,
                                       learning_rate=self.learning_rate,
                                       neg_sample_size=self.neg_sample_size,
                                       batch_size=self.batch_size,
                                       aggregator_type="gcn",
                                       model_size=self.model_size,
                                       identity_dim=self.identity_dim,
                                       concat=False,
                                       logging=True)
        elif self.model_name == 'graphsage_seq':
            # Create model
            sampler = self._create_sampler(sampler_name, adj_info, features)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]
            model = SampleAndAggregate(placeholders,
                                       features,
                                       adj_info,
                                       minibatch.deg,
                                       layer_infos=layer_infos,
                                       weight_decay=self.weight_decay,
                                       learning_rate=self.learning_rate,
                                       neg_sample_size=self.neg_sample_size,
                                       batch_size=self.batch_size,
                                       identity_dim=self.identity_dim,
                                       aggregator_type="seq",
                                       model_size=self.model_size,
                                       logging=True)
        elif self.model_name == 'graphsage_maxpool':
            # Create model
            sampler = self._create_sampler(sampler_name, adj_info, features)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]
            model = SampleAndAggregate(placeholders,
                                       features,
                                       adj_info,
                                       minibatch.deg,
                                       layer_infos=layer_infos,
                                       weight_decay=self.weight_decay,
                                       learning_rate=self.learning_rate,
                                       neg_sample_size=self.neg_sample_size,
                                       batch_size=self.batch_size,
                                       aggregator_type="maxpool",
                                       model_size=self.model_size,
                                       identity_dim=self.identity_dim,
                                       logging=True)
        elif self.model_name == 'graphsage_meanpool':
            # Create model
            sampler = self._create_sampler(sampler_name, adj_info, features)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]
            model = SampleAndAggregate(placeholders,
                                       features,
                                       adj_info,
                                       minibatch.deg,
                                       layer_infos=layer_infos,
                                       weight_decay=self.weight_decay,
                                       learning_rate=self.learning_rate,
                                       neg_sample_size=self.neg_sample_size,
                                       batch_size=self.batch_size,
                                       aggregator_type="meanpool",
                                       model_size=self.model_size,
                                       identity_dim=self.identity_dim,
                                       logging=True)
        else:
            raise Exception('Error: model name unrecognized.')
        return model
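
# Hypothetical call site for the _create_model above; the wrapper object and
# its attributes (model_name, samples_1/2, dim_1/2, weight_decay, ...) are
# assumptions inferred from what the method reads:
#
#   model = trainer._create_model('uniform', placeholders, features,
#                                 adj_info, minibatch)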