def train(train_data):

    G = train_data[0]
    features = train_data[1]
    labels = train_data[2]
    train_nodes = train_data[3]
    test_nodes = train_data[4]
    val_nodes = train_data[5]
    num_classes = 2

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      placeholders,
                                      labels,
                                      train_nodes,
                                      test_nodes,
                                      val_nodes,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree)
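    # The adjacency table is fed once through a placeholder when variables are
    # initialized (see the global_variables_initializer feed_dict below) and
    # kept in a non-trainable Variable.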
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

    model = SupervisedGraphsage(num_classes,
                                placeholders,
                                features,
                                adj_info,
                                minibatch.deg,
                                layer_infos,
                                model_size=FLAGS.model_size,
                                sigmoid_loss=FLAGS.sigmoid,
                                identity_dim=FLAGS.identity_dim,
                                logging=True)

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(results_folder, sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

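    # Assign ops used to switch adj_info between the training adjacency list
    # and the test adjacency list around each validation pass.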
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels, batch = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, acc, prec, rec, f1_score, conf_mat = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, acc, prec, rec, f1_score, conf_mat, fpr, tpr, thresholds = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                # `eval` here is this module's metric helper (it shadows the Python builtin)
                train_acc, train_prec, train_rec, train_f1_score, train_conf_mat, fpr, tpr, thresholds = eval(
                    labels, outs[-1])
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_accuracy=",
                      "{:.5f}".format(train_acc), "train_precision=",
                      "{:.5f}".format(train_prec), "train_recall=",
                      "{:.5f}".format(train_rec), "train_f1_score=",
                      "{:.5f}".format(train_f1_score))
                with open(results_folder + "/train_stats.txt", "a") as fp:
                    fp.write(
                        "Iter:{:d} loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d}\n"
                        .format(iter, train_cost, train_acc, train_prec,
                                train_rec, train_f1_score,
                                train_conf_mat[0][0], train_conf_mat[0][1],
                                train_conf_mat[1][0], train_conf_mat[1][1]))

            iter += 1
            total_steps += 1

            if total_steps > int(FLAGS.max_total_steps):
                break

        if total_steps > int(FLAGS.max_total_steps):
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_acc, val_prec, val_rec, val_f1_score, val_conf_mat = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "val_cost=", "{:.5f}".format(val_cost),
          "val_acc=", "{:.5f}".format(val_acc), "val_prec=",
          "{:.5f}".format(val_prec), "val_rec=", "{:.5f}".format(val_rec),
          "val_f1_score=", "{:.5f}".format(val_f1_score), "val_conf_mat=",
          val_conf_mat)
    with open(results_folder + "/val_stats.txt", "a") as fp:
        fp.write(
            "loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d} time=={:s}\n"
            .format(val_cost, val_acc, val_prec, val_rec, val_f1_score,
                    val_conf_mat[0][0], val_conf_mat[0][1], val_conf_mat[1][0],
                    val_conf_mat[1][1], current_time))
    print("Writing test set stats to file")
    test_cost, test_acc, test_prec, test_rec, test_f1_score, test_conf_mat = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    print("Full test stats:", "test_cost=", "{:.5f}".format(test_cost),
          "test_acc=", "{:.5f}".format(test_acc), "test_prec=",
          "{:.5f}".format(test_prec), "test_rec=", "{:.5f}".format(test_rec),
          "test_f1_score=", "{:.5f}".format(test_f1_score), "test_conf_mat=",
          test_conf_mat)
    with open(results_folder + "/test_stats.txt", "a") as fp:
        fp.write(
            "loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d} time=={:s}\n"
            .format(test_cost, test_acc, test_prec, test_rec, test_f1_score,
                    test_conf_mat[0][0], test_conf_mat[0][1],
                    test_conf_mat[1][0], test_conf_mat[1][1], current_time))
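
# Usage sketch for the train() variant above. The loader name is an assumption
# for illustration (the data-loading code is not shown); only the tuple order is
# fixed, since train() unpacks train_data positionally:
# 0 = G (graph), 1 = features (np.ndarray or None), 2 = labels,
# 3 = train_nodes, 4 = test_nodes, 5 = val_nodes.
#
#   train_data = load_graph_data(FLAGS.train_prefix)   # hypothetical loader
#   train(train_data)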
Example #2
def train(train_data, test_data=None):

    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_f1_mic=",
                      "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                      "{:.5f}".format(train_f1_mac), "val_loss=",
                      "{:.5f}".format(val_cost), "val_f1_mic=",
                      "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                      "{:.5f}".format(val_f1_mac), "time=",
                      "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
          "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write(
            "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            val_cost, val_f1_mic, val_f1_mac))
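
# Usage sketch for the train() variant above (Example #2). train_data is
# unpacked positionally: 0 = G, 1 = features, 2 = id_map, 3 = context pairs
# (random-walk co-occurrences, used only when FLAGS.random_context is set),
# 4 = class_map. class_map maps node id -> label; a list value means
# multi-label classification (num_classes = list length), otherwise
# num_classes = number of distinct labels. The loader call below is an
# assumption (in upstream GraphSAGE a load_data helper fills this role):
#
#   train_data = load_data(FLAGS.train_prefix)   # assumed loader
#   train(train_data)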
Example #3
    def train(self, train_data, test_data=None):
        print("Training model...")
        timer = Timer()
        timer.tic()

        G = train_data[0]
        features = train_data[1]
        id_map = train_data[2]
        class_map = train_data[4]
        if isinstance(list(class_map.values())[0], list):
            num_classes = len(list(class_map.values())[0])
        else:
            num_classes = len(set(class_map.values()))

        if features is not None:
            # pad with dummy zero vector
            features = np.vstack([features, np.zeros((features.shape[1], ))])

        placeholders = self._construct_placeholders(num_classes)
        minibatch = NodeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          class_map,
                                          num_classes,
                                          batch_size=self.batch_size,
                                          max_degree=self.max_degree)

        adj_info_ph = tf.compat.v1.placeholder(tf.int32,
                                               shape=minibatch.adj.shape)
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

        model = self._create_model(num_classes, placeholders, features,
                                   adj_info, minibatch)

        config = tf.compat.v1.ConfigProto(
            log_device_placement=self.log_device_placement)
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        # Initialize session
        sess = tf.compat.v1.Session(config=config)
        merged = tf.compat.v1.summary.merge_all()
        #        summary_writer = tf.summary.FileWriter(self._log_dir(), sess.graph)

        # Initialize model saver
        saver = tf.compat.v1.train.Saver(max_to_keep=self.epochs)

        # Init variables
        sess.run(tf.compat.v1.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})

        # Train model
        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        train_losses = []
        validation_losses = []

        train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
        val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)

        for epoch in range(self.epochs):
            minibatch.shuffle()

            iter = 0
            print('Epoch: %04d' % epoch)
            epoch_val_costs.append(0)
            train_loss_epoch = []
            validation_loss_epoch = []
            while not minibatch.end():
                # Construct feed dictionary
                feed_dict, labels = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: self.dropout})

                t = time.time()
                # Training step
                outs = sess.run(
                    [merged, model.opt_op, model.loss, model.preds],
                    feed_dict=feed_dict)
                train_cost = outs[2]
                train_loss_epoch.append(train_cost)

                if iter % self.validate_iter == 0:
                    # Validation
                    sess.run(val_adj_info.op)
                    if self.validate_batch_size == -1:
                        val_cost, val_f1_mic, val_f1_mac, duration = self._incremental_evaluate(
                            sess, model, minibatch, self.batch_size)
                    else:
                        val_cost, val_f1_mic, val_f1_mac, duration = self._evaluate(
                            sess, model, minibatch, self.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost
                    validation_loss_epoch.append(val_cost)

                # if total_steps % self.print_every == 0:
                #     summary_writer.add_summary(outs[0], total_steps)

                # Print results
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % self.print_every == 0:
                    train_f1_mic, train_f1_mac = self._calc_f1(
                        labels, outs[-1])
                    print("Iter:", '%04d' % iter, "train_loss=",
                          "{:.5f}".format(train_cost), "train_f1_mic=",
                          "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                          "{:.5f}".format(train_f1_mac), "val_loss=",
                          "{:.5f}".format(val_cost), "val_f1_mic=",
                          "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                          "{:.5f}".format(val_f1_mac), "time=",
                          "{:.5f}".format(avg_time))

                iter += 1
                total_steps += 1

                if total_steps > self.max_total_steps:
                    break

            # Keep track of train and validation losses per epoch
            train_losses.append(sum(train_loss_epoch) / len(train_loss_epoch))
            validation_losses.append(
                sum(validation_loss_epoch) / len(validation_loss_epoch))

            # If the epoch has the lowest validation loss so far
            if validation_losses[-1] == min(validation_losses):
                print(
                    "Minimum validation loss so far ({}) at epoch {}.".format(
                        validation_losses[-1], epoch))
                # Save a checkpoint for the best epoch so far
                print("Saving model at epoch {}.".format(epoch))
                saver.save(sess, os.path.join(self._log_dir(), "model.ckpt"))

            if total_steps > self.max_total_steps:
                break

        print("Optimization Finished!")

        training_time = timer.toc()
        self._plot_losses(train_losses, validation_losses)
        self._print_stats(train_losses, validation_losses, training_time)

        sess.run(val_adj_info.op)
        val_cost, val_f1_mic, val_f1_mac, duration = self._incremental_evaluate(
            sess, model, minibatch, self.batch_size)
        print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
              "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
              "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
        with open(self._log_dir() + "val_stats.txt", "w") as fp:
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".
                     format(val_cost, val_f1_mic, val_f1_mac, duration))