def _create_sampler(self, sampler_name, adj_info, features):
    if sampler_name == 'Uniform':
        sampler = UniformNeighborSampler(adj_info)
    elif sampler_name == 'ML':
        sampler = MLNeighborSampler(adj_info, features, self.max_degree,
                                    self.nonlinear_sampler)
    elif sampler_name == 'FastML':
        sampler = FastMLNeighborSampler(adj_info, features,
                                        self.max_degree,
                                        self.nonlinear_sampler)
    else:
        raise Exception('Error: sampler name unrecognized.')
    return sampler
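
A minimal usage sketch for the factory above (hedged: it assumes the stock GraphSAGE package layout and a toy adjacency-list tensor; the ML/FastML samplers come from the fork that defines them and are not exercised here):

import tensorflow as tf
from graphsage.neigh_samplers import UniformNeighborSampler  # assumed import path

# adj_info is an int32 [num_nodes, max_degree] adjacency-list tensor (toy values)
adj_info = tf.constant([[1, 2, 2], [0, 2, 0], [0, 1, 1]], dtype=tf.int32)
sampler = UniformNeighborSampler(adj_info)

ids = tf.constant([0, 2], dtype=tf.int32)
neighbors = sampler((ids, 2))  # -> [len(ids), 2] uniformly sampled neighbor ids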
Example #2
def train(train_data, test_data=None):

    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_f1_mic=",
                      "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                      "{:.5f}".format(train_f1_mac), "val_loss=",
                      "{:.5f}".format(val_cost), "val_f1_mic=",
                      "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                      "{:.5f}".format(val_f1_mac), "time=",
                      "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
          "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write(
            "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            val_cost, val_f1_mic, val_f1_mac))
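
For reference, here is a hedged, self-contained sketch of what the calc_f1 helper called in the training loop above typically looks like in GraphSAGE's supervised_train.py (the original reads the sigmoid flag from FLAGS; here it is an explicit argument):

import numpy as np
from sklearn import metrics

def calc_f1(y_true, y_pred, sigmoid_loss=False):
    if not sigmoid_loss:
        # softmax / single-label case: compare arg-max classes
        y_true = np.argmax(y_true, axis=1)
        y_pred = np.argmax(y_pred, axis=1)
    else:
        # sigmoid / multi-label case: threshold the per-class scores at 0.5
        y_pred = (y_pred > 0.5).astype(int)
    return (metrics.f1_score(y_true, y_pred, average="micro"),
            metrics.f1_score(y_true, y_pred, average="macro"))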
def train(train_data):

    G = train_data[0]
    features = train_data[1]
    labels = train_data[2]
    train_nodes = train_data[3]
    test_nodes = train_data[4]
    val_nodes = train_data[5]
    num_classes = 2

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      placeholders,
                                      labels,
                                      train_nodes,
                                      test_nodes,
                                      val_nodes,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

    model = SupervisedGraphsage(num_classes,
                                placeholders,
                                features,
                                adj_info,
                                minibatch.deg,
                                layer_infos,
                                model_size=FLAGS.model_size,
                                sigmoid_loss=FLAGS.sigmoid,
                                identity_dim=FLAGS.identity_dim,
                                logging=True)

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(results_folder, sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels, batch = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, acc, prec, rec, f1_score, conf_mat = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, acc, prec, rec, f1_score, conf_mat, fpr, tpr, thresholds = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_acc, train_prec, train_rec, train_f1_score, train_conf_mat, fpr, tpr, thresholds = eval(
                    labels, outs[-1])
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_accuracy=",
                      "{:.5f}".format(train_acc), "train_precision=",
                      "{:.5f}".format(train_prec), "train_recall=",
                      "{:.5f}".format(train_rec), "train_f1_score=",
                      "{:.5f}".format(train_f1_score))
                with open(results_folder + "/train_stats.txt", "a") as fp:
                    fp.write(
                        "Iter:{:d} loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d}\n"
                        .format(iter, train_cost, train_acc, train_prec,
                                train_rec, train_f1_score,
                                train_conf_mat[0][0], train_conf_mat[0][1],
                                train_conf_mat[1][0], train_conf_mat[1][1]))

            iter += 1
            total_steps += 1

            if total_steps > int(FLAGS.max_total_steps):
                break

        if total_steps > int(FLAGS.max_total_steps):
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_acc, val_prec, val_rec, val_f1_score, val_conf_mat = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "val_cost=", "{:.5f}".format(val_cost),
          "val_acc=", "{:.5f}".format(val_acc), "val_prec=",
          "{:.5f}".format(val_prec), "val_rec=", "{:.5f}".format(val_rec),
          "val_f1_score=", "{:.5f}".format(val_f1_score), "val_conf_mat=",
          val_conf_mat)
    with open(results_folder + "/val_stats.txt", "a") as fp:
        fp.write(
            "loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d} time=={:s}\n"
            .format(val_cost, val_acc, val_prec, val_rec, val_f1_score,
                    val_conf_mat[0][0], val_conf_mat[0][1], val_conf_mat[1][0],
                    val_conf_mat[1][1], current_time))
    print("Writing test set stats to file")
    test_cost, test_acc, test_prec, test_rec, test_f1_score, test_conf_mat = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    print("Full test stats:", "test_cost=", "{:.5f}".format(test_cost),
          "test_acc=", "{:.5f}".format(test_acc), "test_prec=",
          "{:.5f}".format(test_prec), "test_rec=", "{:.5f}".format(test_rec),
          "test_f1_score=", "{:.5f}".format(test_f1_score), "test_conf_mat=",
          test_conf_mat)
    with open(results_folder + "/test_stats.txt", "a") as fp:
        fp.write(
            "loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d} time=={:s}\n"
            .format(test_cost, test_acc, test_prec, test_rec, test_f1_score,
                    test_conf_mat[0][0], test_conf_mat[0][1],
                    test_conf_mat[1][0], test_conf_mat[1][1], current_time))
    compteur += 1

for node in G_test.nodes():

    dicoIdMap[node] = compteur
    compteur += 1
id_map = dicoIdMap

minibatch = EdgeMinibatchIterator(G,
                                  edgelist,
                                  test_edgelist,
                                  id_map,
                                  batch_size=100,
                                  max_degree=3)

sampler = UniformNeighborSampler(minibatch.adj)

layer_infos = [
    SAGEInfo("node", sampler, 3, 128),
    SAGEInfo("node", sampler, 3, 256)
]

features = []

maximum = -1
"""
for node in G.nodes: 

    text_embed = dicoEmbT[str(node)]
    maximum = max(tf.convert_to_tensor(text_embed.detach().numpy()).shape[1], maximum)
print("MAXIMUM", maximum)
Example #5
    def _create_model(self, num_classes, placeholders, features, adj_info,
                      minibatch):
        if self.model_name == 'graphsage_mean':
            # Create model
            sampler = UniformNeighborSampler(adj_info)
            if self.samples_3 != 0:
                layer_infos = [
                    SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                    SAGEInfo("node", sampler, self.samples_2, self.dim_2),
                    SAGEInfo("node", sampler, self.samples_3, self.dim_2)
                ]
            elif self.samples_2 != 0:
                layer_infos = [
                    SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                    SAGEInfo("node", sampler, self.samples_2, self.dim_2)
                ]
            else:
                layer_infos = [
                    SAGEInfo("node", sampler, self.samples_1, self.dim_1)
                ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        model_size=self.model_size,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        elif self.model_name == 'gcn':
            # Create model
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, 2 * self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, 2 * self.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        aggregator_type="gcn",
                                        model_size=self.model_size,
                                        concat=False,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        elif self.model_name == 'graphsage_seq':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        aggregator_type="seq",
                                        model_size=self.model_size,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        elif self.model_name == 'graphsage_maxpool':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        aggregator_type="maxpool",
                                        model_size=self.model_size,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        elif self.model_name == 'graphsage_meanpool':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, self.samples_1, self.dim_1),
                SAGEInfo("node", sampler, self.samples_2, self.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        weight_decay=self.weight_decay,
                                        learning_rate=self.learning_rate,
                                        aggregator_type="meanpool",
                                        model_size=self.model_size,
                                        sigmoid_loss=self.sigmoid,
                                        identity_dim=self.identity_dim,
                                        logging=True)
        else:
            raise Exception('Error: model name unrecognized.')
        return model
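
For reference, the SAGEInfo tuple used throughout these examples is, in the stock GraphSAGE code (graphsage/models.py), a namedtuple along these lines (hedged sketch; field comments paraphrased):

from collections import namedtuple

SAGEInfo = namedtuple("SAGEInfo", [
    "layer_name",     # name of the layer (used for scoping / feature embeddings)
    "neigh_sampler",  # callable neighbor sampler, e.g. UniformNeighborSampler(adj_info)
    "num_samples",    # number of neighbors to sample at this hop
    "output_dim",     # output (hidden) dimension of this layer
])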
Example #6
File: model.py  Project: yyht/Graph2Seq
    def optimized_gcn_encode(self):
        # [node_size, hidden_layer_dim]
        embedded_node_rep = self.encode_node_feature(self.word_embeddings, self.feature_info)

        fw_sampler = UniformNeighborSampler(self.fw_adj_info)
        bw_sampler = UniformNeighborSampler(self.bw_adj_info)
        nodes = tf.reshape(self.batch_nodes, [-1, ])

        # batch_size = tf.shape(nodes)[0]

        # the fw_hidden and bw_hidden is the initial node embedding
        # [node_size, dim_size]
        fw_hidden = tf.nn.embedding_lookup(embedded_node_rep, nodes)
        bw_hidden = tf.nn.embedding_lookup(embedded_node_rep, nodes)

        # [node_size, adj_size]
        fw_sampled_neighbors = fw_sampler((nodes, self.sample_size_per_layer))
        bw_sampled_neighbors = bw_sampler((nodes, self.sample_size_per_layer))

        fw_sampled_neighbors_len = tf.constant(0)
        bw_sampled_neighbors_len = tf.constant(0)

        # sample
        for layer in range(self.sample_layer_size):
            if layer == 0:
                dim_mul = 1
            else:
                dim_mul = 2

            if layer > 6:
                fw_aggregator = self.fw_aggregators[6]
            else:
                fw_aggregator = MeanAggregator(dim_mul * self.hidden_layer_dim, self.hidden_layer_dim, concat=self.concat, mode=self.mode)
                self.fw_aggregators.append(fw_aggregator)

            # [node_size, adj_size, word_embedding_dim]
            if layer == 0:
                neigh_vec_hidden = tf.nn.embedding_lookup(embedded_node_rep, fw_sampled_neighbors)

                # compute the neighbor size
                tmp_sum = tf.reduce_sum(tf.nn.relu(neigh_vec_hidden), axis=2)
                tmp_mask = tf.sign(tmp_sum)
                fw_sampled_neighbors_len = tf.reduce_sum(tmp_mask, axis=1)

            else:
                neigh_vec_hidden = tf.nn.embedding_lookup(
                    tf.concat([fw_hidden, tf.zeros([1, dim_mul * self.hidden_layer_dim])], 0), fw_sampled_neighbors)

            fw_hidden = fw_aggregator((fw_hidden, neigh_vec_hidden, fw_sampled_neighbors_len))


            if self.graph_encode_direction == "bi":
                if layer > 6:
                    bw_aggregator = self.bw_aggregators[6]
                else:
                    bw_aggregator = MeanAggregator(dim_mul * self.hidden_layer_dim, self.hidden_layer_dim, concat=self.concat, mode=self.mode)
                    self.bw_aggregators.append(bw_aggregator)

                if layer == 0:
                    neigh_vec_hidden = tf.nn.embedding_lookup(embedded_node_rep, bw_sampled_neighbors)

                    # compute the neighbor size
                    tmp_sum = tf.reduce_sum(tf.nn.relu(neigh_vec_hidden), axis=2)
                    tmp_mask = tf.sign(tmp_sum)
                    bw_sampled_neighbors_len = tf.reduce_sum(tmp_mask, axis=1)

                else:
                    neigh_vec_hidden = tf.nn.embedding_lookup(
                        tf.concat([bw_hidden, tf.zeros([1, dim_mul * self.hidden_layer_dim])], 0), bw_sampled_neighbors)

                bw_hidden = bw_aggregator((bw_hidden, neigh_vec_hidden, bw_sampled_neighbors_len))

        # hidden stores the representation for all nodes
        fw_hidden = tf.reshape(fw_hidden, [-1, self.single_graph_nodes_size, 2 * self.hidden_layer_dim])
        if self.graph_encode_direction == "bi":
            bw_hidden = tf.reshape(bw_hidden, [-1, self.single_graph_nodes_size, 2 * self.hidden_layer_dim])
            hidden = tf.concat([fw_hidden, bw_hidden], axis=2)
        else:
            hidden = fw_hidden

        hidden = tf.nn.relu(hidden)

        pooled = tf.reduce_max(hidden, 1)
        if self.graph_encode_direction == "bi":
            graph_embedding = tf.reshape(pooled, [-1, 4 * self.hidden_layer_dim])
        else:
            graph_embedding = tf.reshape(pooled, [-1, 2 * self.hidden_layer_dim])

        graph_embedding = LSTMStateTuple(c=graph_embedding, h=graph_embedding)

        # shape of hidden: [batch_size, single_graph_nodes_size, 4 * hidden_layer_dim]
        # shape of graph_embedding: ([batch_size, 4 * hidden_layer_dim], [batch_size, 4 * hidden_layer_dim])
        return hidden, graph_embedding
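
The tmp_sum/tmp_mask lines above count how many sampled neighbors carry a non-zero embedding (padded neighbor ids look up an all-zero row); a small runnable NumPy illustration of the trick, assuming real embeddings have at least one positive component:

import numpy as np

# [node_size=1, adj_size=3, dim=2]; the middle "neighbor" is a padding row
neigh_vec_hidden = np.array([[[0.3, 0.1],
                              [0.0, 0.0],
                              [0.7, 0.2]]])
tmp_sum = np.sum(np.maximum(neigh_vec_hidden, 0.0), axis=2)  # relu + sum over dim
tmp_mask = np.sign(tmp_sum)                                  # 1 for real, 0 for padding
neighbors_len = np.sum(tmp_mask, axis=1)
print(neighbors_len)  # [2.]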
Example #7
def train(train_data, test_data=None):
    features, label_map, \
        train_nodes, valid_nodes, test_nodes, \
        train_adj, train_weight_adj, train_column_adj, \
        test_adj, test_weight_adj, test_column_adj = train_data

    # if isinstance(list(class_map.values())[0], list):
    #     num_classes = len(list(class_map.values())[0])
    # else:
    #     num_classes = len(set(class_map.values()))

    num_classes = label_map.shape[1]
    feats_dim = features.shape[1]

    # Not sure the zero-row padding is actually useful here
    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((feats_dim, ))])
    # Not sure why Variable(constant(...), trainable=False) is needed here; it looks odd
    features_info = tf.Variable(tf.constant(features, dtype=tf.float32),
                                trainable=False)

    #context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes, feats_dim)
    minibatch = NodeMinibatchIterator(
        placeholders,
        #   features,
        #   id_map,
        #   weight_map,
        label_map,
        #   weight_dict,
        supervised_info=[train_nodes, valid_nodes, test_nodes],
        batch_size=FLAGS.batch_size,
        max_degree=FLAGS.max_degree)

    # Note: these are placeholders, and they hold the full-size adjacency
    # TODO: the shape also carries the data info, (, train_adj.shape)
    adj_info_ph = tf.placeholder(tf.int32, shape=train_adj.shape)
    weight_adj_info_ph = tf.placeholder(tf.float32,
                                        shape=train_weight_adj.shape)
    column_adj_info_ph = tf.placeholder(tf.int32, shape=train_column_adj.shape)

    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
    weight_adj_info = tf.Variable(weight_adj_info_ph,
                                  trainable=False,
                                  name='weight_adj_info')
    column_adj_info = tf.Variable(column_adj_info_ph,
                                  trainable=False,
                                  name='column_adj_info')

    # These are assign ops; nothing is actually assigned until they are run
    train_adj_info = tf.assign(adj_info, train_adj)
    val_adj_info = tf.assign(adj_info, test_adj)

    train_weight_adj_info = tf.assign(weight_adj_info, train_weight_adj)
    val_weight_adj_info = tf.assign(weight_adj_info, test_weight_adj)

    train_column_adj_info = tf.assign(column_adj_info, train_column_adj)
    val_column_adj_info = tf.assign(column_adj_info, test_column_adj)

    # Sampling
    # TODO: the feature data still goes in through here
    # TODO: it should be pulled out of the sampler
    sampler = UniformNeighborSampler(features_info, adj_info, weight_adj_info,
                                     column_adj_info)

    # === build model ===
    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info, weight_adj_info,
                                         column_adj_info)
        # 16, 8
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    concat=True,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info, weight_adj_info,
                                         column_adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'geniepath':
        sampler = UniformNeighborSampler(adj_info, weight_adj_info,
                                         column_adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="geniepath",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'cross':
        # Create model
        # if FLAGS.samples_3 != 0:
        #     layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
        #                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
        #                    SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)]
        # elif FLAGS.samples_2 != 0:
        #     layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
        #                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
        # else:
        #     layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)]
        layer_infos = [
            SAGEInfo("node", FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", FLAGS.samples_2, FLAGS.dim_2)
        ]
        model = SupervisedGraphsage(
            placeholders,
            feats_dim,
            num_classes,
            sampler,
            # features,
            # adj_info, # variable
            # minibatch.deg,
            layer_infos=layer_infos,
            aggregator_type='cross',  # this line is the only difference
            concat=True,
            model_size=FLAGS.model_size,
            sigmoid_loss=FLAGS.sigmoid,
            identity_dim=FLAGS.identity_dim,
            logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    # merged = tf.summary.merge_all()
    # summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={
                 adj_info_ph: train_adj,
                 weight_adj_info_ph: train_weight_adj,
                 column_adj_info_ph: train_column_adj
             })

    # === Train model ===
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    for epoch in range(FLAGS.train_epochs):
        minibatch.shuffle()

        iter = 0
        print('\n### Epoch: %04d ###' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()  # the full feed comes in each time
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            # t = time.time()
            outs = sess.run([model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            # outs = sess.run([merged, model.opt_op, model.loss, model.preds],
            #                 feed_dict=feed_dict)
            # outs = sess.run([merged, model.loss, model.preds],feed_dict=feed_dict)
            # train_cost = outs[2]

            # if iter % FLAGS.validate_iter == 0:
            #     # Validation
            #     # do the assign operation
            #     sess.run([val_adj_info.op, val_weight_adj_info.op, val_column_adj_info.op])

            #     # if a validation batch size is set
            #     if FLAGS.validate_batch_size == -1:
            #         val_cost, val_f1_mic, val_f1_mac,report, duration, _ = incremental_evaluate(
            #             sess, model, minibatch, FLAGS.batch_size)
            #     else:
            #         val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
            #             sess, model, minibatch, FLAGS.validate_batch_size)

            #     sess.run([train_adj_info.op, train_weight_adj_info.op, train_column_adj_info.op])

            #     epoch_val_costs[-1] += val_cost

            # if total_steps % FLAGS.print_every == 0:
            #     summary_writer.add_summary(outs[0], total_steps)

            # Print results
            # avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)
            # print("train_time=", "{:.5f}".format(avg_time))

            # if total_steps % FLAGS.print_every == 0:
            # train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
            # train_accuracy = calc_acc(labels,outs[-1])
            # report = classification_report(labels,outs[-1])
            # print("Iter:", '%04d' % iter,
            #       "train_loss=", "{:.5f}".format(train_cost),
            #       "train_f1_mic=", "{:.5f}".format(train_f1_mic),
            #       "train_f1_mac=", "{:.5f}".format(train_f1_mac),
            #       "val_loss=", "{:.5f}".format(val_cost),
            #       "val_f1_mic=", "{:.5f}".format(val_f1_mic),
            #       "val_f1_mac=", "{:.5f}".format(val_f1_mac),
            #       "time=", "{:.5f}".format(avg_time))
            #print(report)

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        # when each epoch ends
        # show the F1 report
        if epoch % 1 == 0:

            # sess.run([val_adj_info.op, val_weight_adj_info.op, val_column_adj_info.op])
            sess.run([val_adj_info, val_weight_adj_info, val_column_adj_info])

            # val_cost, val_f1_mic, val_f1_mac, report, duration = incremental_evaluate(
            #     sess, model, minibatch, FLAGS.batch_size)
            # area = my_incremental_evaluate(
            #     sess, model, minibatch, FLAGS.batch_size)
            # # precision, recall, thresholds = precision_recall_curve(
            # #     val_labels[:, 1], val_preds[:, 1])
            # # area2 = auc(recall, precision)

            # print("Full validation stats:",
            #       "loss=", "{:.5f}".format(val_cost),
            #       "f1_micro=", "{:.5f}".format(val_f1_mic),
            #       "f1_macro=", "{:.5f}".format(val_f1_mac),
            #       "time=", "{:.5f}".format(duration))

            # print(report)
            # print('AUC',area)

            test_cost, test_f1_mic, test_f1_mac, report, duration = incremental_evaluate(
                sess, model, minibatch, FLAGS.batch_size, test=True)
            area = my_incremental_evaluate(sess,
                                           model,
                                           minibatch,
                                           FLAGS.batch_size,
                                           test=True)
            # precision, recall, thresholds = precision_recall_curve(
            #     test_labels[:, 1], test_preds[:, 1])
            # area2 = auc(recall, precision)

            print("Full Test stats:", "loss=", "{:.5f}".format(test_cost),
                  "f1_micro=", "{:.5f}".format(test_f1_mic), "f1_macro=",
                  "{:.5f}".format(test_f1_mac), "time=",
                  "{:.5f}".format(duration))
            print(report)
            print('AUC', area)

            # once AUC > 0.83, save the model
            if area > 0.83:
                model.save(sess)
                print('AUC gotcha! model saved.')

                # np.save('../data/'+FLAGS.model+'aggr'+'_precision',precision)
                # np.save('../data/'+FLAGS.model+'aggr'+'_recall',recall)

        # TODO: early stopping should be added here
        if total_steps > FLAGS.max_total_steps:
            break

    # model.save(sess)
    print("Optimization Finished!")

    sess.run([val_adj_info.op, val_weight_adj_info.op, val_column_adj_info.op])
    # val_cost, val_f1_mic, val_f1_mac, report, duration, area = incremental_evaluate(
    #     sess, model, minibatch, FLAGS.batch_size)
    # area = my_incremental_evaluate(
    #     sess, model, minibatch, FLAGS.batch_size)
    # precision, recall, thresholds = precision_recall_curve(
    #     val_labels[:, 1], val_preds[:, 1])
    # area = auc(recall, precision)

    # np.save('../data/val_preds.npy', val_preds)
    # np.save('../data/val_labels.npy', val_labels)
    # np.save('../data/val_cost.npy', v_cost)

    # print("Full validation stats:",
    #       "loss=", "{:.5f}".format(val_cost),
    #       "f1_micro=", "{:.5f}".format(val_f1_mic),
    #       "f1_macro=", "{:.5f}".format(val_f1_mac),
    #       "time=", "{:.5f}".format(duration))
    # print(report)
    # print('AUC', area)

    # with open(log_dir() + "val_stats.txt", "w") as fp:
    #     fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".
    #              format(val_cost, val_f1_mic, val_f1_mac, duration))

    test_cost, test_f1_mic, test_f1_mac, report, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    area = my_incremental_evaluate(sess,
                                   model,
                                   minibatch,
                                   FLAGS.batch_size,
                                   test=True)
    # precision, recall, thresholds = precision_recall_curve(
    #     test_labels[:, 1], test_preds[:, 1])
    # area = auc(recall, precision)

    # np.save('../data/test_preds.npy', test_preds)
    # np.save('../data/test_labels.npy', test_labels)
    # np.save('../data/test_cost.npy', t_cost) # prevent from override

    print("Full Test stats:", "loss=", "{:.5f}".format(test_cost), "f1_micro=",
          "{:.5f}".format(test_f1_mic), "f1_macro=",
          "{:.5f}".format(test_f1_mac), "time=", "{:.5f}".format(duration))
    print(report)
    print('AUC:', area)

    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            test_cost, test_f1_mic, test_f1_mac))
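
The placeholder-into-Variable pattern above (adj_info, weight_adj_info, column_adj_info) keeps the large adjacency arrays out of the graph definition and lets tf.assign swap between train and test views; a minimal TF1-style sketch of the same pattern with a single adjacency (toy arrays, hypothetical names):

import numpy as np
import tensorflow as tf

train_adj = np.zeros((10, 3), dtype=np.int32)
test_adj = np.ones((10, 3), dtype=np.int32)

adj_info_ph = tf.placeholder(tf.int32, shape=train_adj.shape)
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info_demo")
use_train_adj = tf.assign(adj_info, train_adj)
use_test_adj = tf.assign(adj_info, test_adj)

sess = tf.Session()
# the placeholder is only needed once, to initialize the variable
sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: train_adj})
sess.run(use_test_adj)   # switch to the test adjacency for evaluation
sess.run(use_train_adj)  # switch back for training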
Example #8
def train(train_data, test_data=None):

    # adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask

    G = train_data[0]
    features = train_data[1]
    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_mrr=",
                      "{:.5f}".format(train_mrr), "train_mrr_ema=",
                      "{:.5f}".format(train_shadow_mrr))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())
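
The train_mrr_ema reported above is an exponential moving average with decay 0.99; the in-place update train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr) is equivalent to the usual EMA form shown in this small sketch (made-up MRR values):

decay = 0.99
shadow_mrr = None
for mrr in [0.20, 0.35, 0.30]:  # per-step MRR values, made up for illustration
    if shadow_mrr is None:
        shadow_mrr = mrr        # initialize the EMA with the first observation
    else:
        shadow_mrr = decay * shadow_mrr + (1 - decay) * mrr
    print("{:.5f}".format(shadow_mrr))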
Example #9
    def gcn_encode(self,
                   batch_nodes,
                   embedded_node_rep,
                   fw_adj_info,
                   bw_adj_info,
                   input_node_dim,
                   output_node_dim,
                   fw_aggregators,
                   bw_aggregators,
                   window_size,
                   layer_size,
                   scope,
                   agg_type,
                   sample_size_per_layer,
                   keep_inter_state=False):
        with tf.variable_scope(scope):
            single_graph_nodes_size = tf.shape(batch_nodes)[1]
            # ============ encode graph structure ==========
            fw_sampler = UniformNeighborSampler(fw_adj_info)
            bw_sampler = UniformNeighborSampler(bw_adj_info)
            nodes = tf.reshape(batch_nodes, [
                -1,
            ])

            # the fw_hidden and bw_hidden is the initial node embedding
            # [node_size, dim_size]
            fw_hidden = tf.nn.embedding_lookup(embedded_node_rep, nodes)
            bw_hidden = tf.nn.embedding_lookup(embedded_node_rep, nodes)

            # [node_size, adj_size]
            fw_sampled_neighbors = fw_sampler((nodes, sample_size_per_layer))
            bw_sampled_neighbors = bw_sampler((nodes, sample_size_per_layer))

            inter_fw_hiddens = []
            inter_bw_hiddens = []
            inter_dims = []

            if scope == "first_gcn":
                self.watch["node_1_rep_in_first_gcn"] = []

            fw_hidden_dim = input_node_dim
            # layer is the index of convolution and hop is used to combine information
            for layer in range(layer_size):
                self.watch["node_1_rep_in_first_gcn"].append(fw_hidden)

                if len(fw_aggregators) <= layer:
                    fw_aggregators.append([])
                if len(bw_aggregators) <= layer:
                    bw_aggregators.append([])
                for hop in range(window_size):
                    if hop > 6:
                        fw_aggregator = fw_aggregators[layer][6]
                    elif len(fw_aggregators[layer]) > hop:
                        fw_aggregator = fw_aggregators[layer][hop]
                    else:
                        if agg_type == "GCN":
                            fw_aggregator = GCNAggregator(fw_hidden_dim,
                                                          output_node_dim,
                                                          concat=self.concat,
                                                          dropout=self.dropout,
                                                          mode=self.mode)
                        elif agg_type == "mean_pooling":
                            fw_aggregator = MeanAggregator(
                                fw_hidden_dim,
                                output_node_dim,
                                concat=self.concat,
                                dropout=self.dropout,
                                if_use_high_way=self.with_gcn_highway,
                                mode=self.mode)
                        elif agg_type == "max_pooling":
                            fw_aggregator = MaxPoolingAggregator(
                                fw_hidden_dim,
                                output_node_dim,
                                concat=self.concat,
                                dropout=self.dropout,
                                mode=self.mode)
                        elif agg_type == "lstm":
                            fw_aggregator = SeqAggregator(fw_hidden_dim,
                                                          output_node_dim,
                                                          concat=self.concat,
                                                          dropout=self.dropout,
                                                          mode=self.mode)
                        elif agg_type == "att":
                            fw_aggregator = AttentionAggregator(
                                fw_hidden_dim,
                                output_node_dim,
                                concat=self.concat,
                                dropout=self.dropout,
                                mode=self.mode)

                        fw_aggregators[layer].append(fw_aggregator)

                    # [node_size, adj_size, word_embedding_dim]
                    if layer == 0 and hop == 0:
                        neigh_vec_hidden = tf.nn.embedding_lookup(
                            embedded_node_rep, fw_sampled_neighbors)
                    else:
                        neigh_vec_hidden = tf.nn.embedding_lookup(
                            tf.concat(
                                [fw_hidden,
                                 tf.zeros([1, fw_hidden_dim])], 0),
                            fw_sampled_neighbors)
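                    # A zero row is appended to fw_hidden so that an index one past
                    # the last real node (presumably the sampler's padding value)
                    # looks up an all-zero vector rather than a real node state.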

                    # if self.with_gcn_highway:
                    #     # we try to forget something when introducing the neighbor information
                    #     with tf.variable_scope("fw_hidden_highway"):
                    #         fw_hidden = multi_highway_layer(fw_hidden, fw_hidden_dim, options['highway_layer_num'])

                    bw_hidden_dim = fw_hidden_dim

                    fw_hidden, fw_hidden_dim = fw_aggregator(
                        (fw_hidden, neigh_vec_hidden))

                    if keep_inter_state:
                        inter_fw_hiddens.append(fw_hidden)
                        inter_dims.append(fw_hidden_dim)

                    if self.graph_encode_direction == "bi":
                        if hop > 6:
                            bw_aggregator = bw_aggregators[layer][6]
                        elif len(bw_aggregators[layer]) > hop:
                            bw_aggregator = bw_aggregators[layer][hop]
                        else:
                            if agg_type == "GCN":
                                bw_aggregator = GCNAggregator(
                                    bw_hidden_dim,
                                    output_node_dim,
                                    concat=self.concat,
                                    dropout=self.dropout,
                                    mode=self.mode)
                            elif agg_type == "mean_pooling":
                                bw_aggregator = MeanAggregator(
                                    bw_hidden_dim,
                                    output_node_dim,
                                    concat=self.concat,
                                    dropout=self.dropout,
                                    if_use_high_way=self.with_gcn_highway,
                                    mode=self.mode)
                            elif agg_type == "max_pooling":
                                bw_aggregator = MaxPoolingAggregator(
                                    bw_hidden_dim,
                                    output_node_dim,
                                    concat=self.concat,
                                    dropout=self.dropout,
                                    mode=self.mode)
                            elif agg_type == "lstm":
                                bw_aggregator = SeqAggregator(
                                    bw_hidden_dim,
                                    output_node_dim,
                                    concat=self.concat,
                                    dropout=self.dropout,
                                    mode=self.mode)
                            elif agg_type == "att":
                                bw_aggregator = AttentionAggregator(
                                    bw_hidden_dim,
                                    output_node_dim,
                                    concat=self.concat,
                                    mode=self.mode,
                                    dropout=self.dropout)

                            bw_aggregators[layer].append(bw_aggregator)

                        if layer == 0 and hop == 0:
                            neigh_vec_hidden = tf.nn.embedding_lookup(
                                embedded_node_rep, bw_sampled_neighbors)
                        else:
                            neigh_vec_hidden = tf.nn.embedding_lookup(
                                tf.concat(
                                    [bw_hidden,
                                     tf.zeros([1, fw_hidden_dim])], 0),
                                bw_sampled_neighbors)

                        if self.with_gcn_highway:
                            with tf.variable_scope("bw_hidden_highway"):
                                bw_hidden = multi_highway_layer(
                                    bw_hidden, fw_hidden_dim,
                                    options['highway_layer_num'])

                        bw_hidden, bw_hidden_dim = bw_aggregator(
                            (bw_hidden, neigh_vec_hidden))

                        if keep_inter_state:
                            inter_bw_hiddens.append(bw_hidden)

            node_dim = fw_hidden_dim

            # hidden stores the representation for all nodes
            fw_hidden = tf.reshape(fw_hidden,
                                   [-1, single_graph_nodes_size, node_dim])
            if self.graph_encode_direction == "bi":
                bw_hidden = tf.reshape(bw_hidden,
                                       [-1, single_graph_nodes_size, node_dim])
                hidden = tf.concat([fw_hidden, bw_hidden], axis=2)
                graph_dim = 2 * node_dim
            else:
                hidden = fw_hidden
                graph_dim = node_dim

            hidden = tf.nn.relu(hidden)
            max_pooled = tf.reduce_max(hidden, 1)
            mean_pooled = tf.reduce_mean(hidden, 1)
            res = [hidden]

            max_graph_embedding = tf.reshape(max_pooled, [-1, graph_dim])
            mean_graph_embedding = tf.reshape(mean_pooled, [-1, graph_dim])
            res.append(max_graph_embedding)
            res.append(mean_graph_embedding)
            res.append(graph_dim)
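            # res layout so far: [node_states, max_pooled_graph_embedding,
            # mean_pooled_graph_embedding, graph_dim]; the per-hop intermediate
            # states are appended below when keep_inter_state is set.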

            if keep_inter_state:
                inter_node_reps = []
                inter_graph_reps = []
                inter_graph_dims = []
                # process the inter hidden states
                for idx in range(len(inter_fw_hiddens)):
                    inter_fw_hidden = inter_fw_hiddens[idx]
                    inter_dim = inter_dims[idx]
                    inter_fw_hidden = tf.reshape(
                        inter_fw_hidden,
                        [-1, single_graph_nodes_size, inter_dim])

                    if self.graph_encode_direction == "bi":
                        # inter_bw_hiddens is only filled in bi-directional mode,
                        # so it is indexed inside this branch.
                        inter_bw_hidden = tf.reshape(
                            inter_bw_hiddens[idx],
                            [-1, single_graph_nodes_size, inter_dim])
                        inter_hidden = tf.concat(
                            [inter_fw_hidden, inter_bw_hidden], axis=2)
                        inter_graph_dim = inter_dim * 2
                    else:
                        inter_hidden = inter_fw_hidden
                        inter_graph_dim = inter_dim

                    inter_node_rep = tf.nn.relu(inter_hidden)
                    inter_node_reps.append(inter_node_rep)
                    inter_graph_dims.append(inter_graph_dim)

                    max_pooled_tmp = tf.reduce_max(inter_node_rep, 1)
                    mean_pooled_tmp = tf.reduce_mean(inter_node_rep, 1)
                    max_graph_embedding = tf.reshape(max_pooled_tmp,
                                                     [-1, inter_graph_dim])
                    mean_graph_embedding = tf.reshape(mean_pooled_tmp,
                                                      [-1, inter_graph_dim])
                    inter_graph_reps.append(
                        (max_graph_embedding, mean_graph_embedding))

                res.append(inter_node_reps)
                res.append(inter_graph_reps)
                res.append(inter_graph_dims)

            return res
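
# Illustration (not part of the original code): a minimal NumPy sketch of the
# graph read-out used above, assuming hypothetical node states shaped
# [batch, single_graph_nodes_size, node_dim]. It mirrors the ReLU followed by
# max- and mean-pooling over the node axis.
import numpy as np

def graph_readout_sketch(node_states):
    # node_states: [batch, nodes, dim]
    hidden = np.maximum(node_states, 0.0)  # tf.nn.relu
    max_pooled = hidden.max(axis=1)        # tf.reduce_max(hidden, 1)
    mean_pooled = hidden.mean(axis=1)      # tf.reduce_mean(hidden, 1)
    return max_pooled, mean_pooled

# Hypothetical shapes: 2 graphs, 5 nodes each, 4-dim node states.
_max, _mean = graph_readout_sketch(np.random.rand(2, 5, 4).astype(np.float32))
assert _max.shape == (2, 4) and _mean.shape == (2, 4)
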
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
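    # adj_info is a non-trainable variable fed through adj_info_ph when the
    # variables are initialized, a common TF1 pattern to avoid baking the large
    # adjacency array into the graph definition as a constant.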

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]
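        # Two rounds of neighbor sampling: FLAGS.samples_1 neighbors per node at
        # the first layer (output dimension FLAGS.dim_1) and FLAGS.samples_2 at
        # the second (output dimension FLAGS.dim_2).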

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            # 2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
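    # Running these assign ops switches the adjacency used by the samplers
    # between the training graph and the validation/test graph around each
    # evaluation call.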
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)
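            # Equivalent to 0.99 * train_shadow_mrr + 0.01 * train_mrr, i.e. an
            # exponential moving average of the training MRR with decay 0.99.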

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss=",
                    "{:.5f}".format(val_cost),
                    "val_mrr=",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema=",
                    "{:.5f}".format(shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
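            # scatter_nd + stop_gradient rebuild the embedding table so that rows
            # for train nodes are frozen while rows for val/test nodes stay
            # trainable during the n2v test-time retraining below.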
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)