コード例 #1
0
ファイル: supervised_train.py プロジェクト: naz947/Bug_Triage
def train(train_data, test_data=None):
    """Train a supervised GraphSAGE-style model and write validation/test stats.

    Args:
        train_data: indexable bundle read as ``G = train_data[0]``,
            ``features = train_data[1]`` (may be None),
            ``id_map = train_data[2]``, context pairs ``train_data[3]`` (read
            only when ``FLAGS.random_context`` is set) and
            ``class_map = train_data[4]``.
        test_data: not read; kept for interface compatibility.

    Side effects:
        Writes TensorBoard summaries to ``log_dir()`` and the files
        ``val_stats.txt`` / ``test_stats.txt`` under ``log_dir()``.

    Raises:
        Exception: if ``FLAGS.model`` is not a supported model name.
    """
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    first_label = list(class_map.values())[0]
    if isinstance(first_label, list):
        # List-valued labels: the list length is taken as the class count.
        num_classes = len(first_label)
    else:
        # Scalar labels: count the distinct label values.
        num_classes = len(set(class_map.values()))

    if features is not None:
        # Pad with a dummy all-zero row appended after the real feature rows.
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    # Explicit aggregator for every supported model except 'graphsage_mean',
    # which relies on SupervisedGraphsage's default aggregator.
    aggregator_by_model = {
        'gcn': 'gcn',
        'graphsage_seq': 'seq',
        'graphsage_maxpool': 'maxpool',
        'graphsage_meanpool': 'meanpool',
    }
    if FLAGS.model == 'graphsage_mean':
        sampler = UniformNeighborSampler(adj_info)
        # One, two or three sampling layers depending on which counts are set.
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model in aggregator_by_model:
        sampler = UniformNeighborSampler(adj_info)
        is_gcn = FLAGS.model == 'gcn'
        # The GCN variant doubles both layer widths and disables concat; the
        # other aggregators use the configured dims unchanged.
        width = 2 if is_gcn else 1
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, width * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, width * FLAGS.dim_2)
        ]
        extra_kwargs = {'concat': False} if is_gcn else {}

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type=aggregator_by_model[FLAGS.model],
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True,
                                    **extra_kwargs)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Initialize session; both it and the summary writer are closed in
    # `finally` so buffered summaries are flushed and resources released.
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
    try:
        sess.run(tf.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})

        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        # Ops that swap the adjacency variable between minibatch.adj (training)
        # and minibatch.test_adj (used for every evaluation).
        train_adj_info = tf.assign(adj_info, minibatch.adj)
        val_adj_info = tf.assign(adj_info, minibatch.test_adj)
        for epoch in range(FLAGS.epochs):
            minibatch.shuffle()

            iter_num = 0
            print('Epoch: %04d' % (epoch + 1))
            epoch_val_costs.append(0)
            while not minibatch.end():
                feed_dict, labels = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})

                t = time.time()
                # Training step.
                outs = sess.run(
                    [merged, model.opt_op, model.loss, model.preds],
                    feed_dict=feed_dict)
                train_cost = outs[2]

                # iter_num restarts at 0 every epoch, so validation runs on each
                # epoch's first step and val_* are set before being printed.
                if iter_num % FLAGS.validate_iter == 0:
                    sess.run(val_adj_info.op)
                    if FLAGS.validate_batch_size == -1:
                        val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                            sess, model, minibatch, FLAGS.batch_size)
                    else:
                        val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                            sess, model, minibatch, FLAGS.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost

                if total_steps % FLAGS.print_every == 0:
                    summary_writer.add_summary(outs[0], total_steps)

                # Running mean of per-step wall time (includes validation).
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % FLAGS.print_every == 0:
                    train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                    print("Iter:", '%04d' % iter_num, "train_loss=",
                          "{:.5f}".format(train_cost), "train_f1_mic=",
                          "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                          "{:.5f}".format(train_f1_mac), "val_loss=",
                          "{:.5f}".format(val_cost), "val_f1_mic=",
                          "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                          "{:.5f}".format(val_f1_mac), "time=",
                          "{:.5f}".format(avg_time))

                iter_num += 1
                total_steps += 1

                if total_steps > FLAGS.max_total_steps:
                    break

            if total_steps > FLAGS.max_total_steps:
                break

        print("Optimization Finished!")
        sess.run(val_adj_info.op)
        val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
            sess, model, minibatch, FLAGS.batch_size)
        print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
              "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
              "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
        with open(log_dir() + "val_stats.txt", "w") as fp:
            fp.write(
                "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format(
                    val_cost, val_f1_mic, val_f1_mac, duration))

        print("Writing test set stats to file (don't peak!)")
        test_cost, test_f1_mic, test_f1_mac, _ = incremental_evaluate(
            sess, model, minibatch, FLAGS.batch_size, test=True)
        with open(log_dir() + "test_stats.txt", "w") as fp:
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
                test_cost, test_f1_mic, test_f1_mac))
    finally:
        summary_writer.close()
        sess.close()
コード例 #2
0
def train(train_data):
    """Train the 'graphsage_mean' supervised model on a two-class problem.

    Args:
        train_data: indexable bundle read as ``G = train_data[0]``,
            ``features = train_data[1]`` (may be None),
            ``labels = train_data[2]`` and the train / test / val node
            collections at indices 3 / 4 / 5.

    Side effects:
        Writes TensorBoard summaries to the global ``results_folder`` and
        appends to ``train_stats.txt``, ``val_stats.txt`` and
        ``test_stats.txt`` inside it; val/test lines are stamped with the
        global ``current_time``.

    Raises:
        Exception: if ``FLAGS.model`` is anything other than
            ``'graphsage_mean'``.
    """
    G = train_data[0]
    features = train_data[1]
    labels = train_data[2]
    train_nodes = train_data[3]
    test_nodes = train_data[4]
    val_nodes = train_data[5]
    num_classes = 2

    if features is not None:
        # Pad with a dummy all-zero row appended after the real feature rows.
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      placeholders,
                                      labels,
                                      train_nodes,
                                      test_nodes,
                                      val_nodes,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    # Only the mean-aggregator model is wired up here. Without this check any
    # other name would crash below on an unbound `layer_infos`.
    if FLAGS.model != 'graphsage_mean':
        raise Exception('Error: model name unrecognized.')

    sampler = UniformNeighborSampler(adj_info)
    # One, two or three sampling layers depending on which counts are set.
    if FLAGS.samples_3 != 0:
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
            SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
        ]
    elif FLAGS.samples_2 != 0:
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]
    else:
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
        ]

    model = SupervisedGraphsage(num_classes,
                                placeholders,
                                features,
                                adj_info,
                                minibatch.deg,
                                layer_infos,
                                model_size=FLAGS.model_size,
                                sigmoid_loss=FLAGS.sigmoid,
                                identity_dim=FLAGS.identity_dim,
                                logging=True)

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Initialize session; both it and the summary writer are closed in
    # `finally` so buffered summaries are flushed and resources released.
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(results_folder, sess.graph)
    try:
        sess.run(tf.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})

        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        # Ops that swap the adjacency variable between minibatch.adj (training)
        # and minibatch.test_adj (used for every evaluation).
        train_adj_info = tf.assign(adj_info, minibatch.adj)
        val_adj_info = tf.assign(adj_info, minibatch.test_adj)
        for epoch in range(FLAGS.epochs):
            minibatch.shuffle()

            iter_num = 0
            print('Epoch: %04d' % (epoch + 1))
            epoch_val_costs.append(0)
            while not minibatch.end():
                # Per-batch labels are kept apart from the dataset-level `labels`.
                feed_dict, batch_labels, _batch = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})

                t = time.time()
                # Training step.
                outs = sess.run(
                    [merged, model.opt_op, model.loss, model.preds],
                    feed_dict=feed_dict)
                train_cost = outs[2]

                if iter_num % FLAGS.validate_iter == 0:
                    sess.run(val_adj_info.op)
                    if FLAGS.validate_batch_size == -1:
                        val_cost, acc, prec, rec, f1_score, conf_mat = incremental_evaluate(
                            sess, model, minibatch, FLAGS.batch_size)
                    else:
                        val_cost, acc, prec, rec, f1_score, conf_mat, fpr, tpr, thresholds = evaluate(
                            sess, model, minibatch, FLAGS.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost

                if total_steps % FLAGS.print_every == 0:
                    summary_writer.add_summary(outs[0], total_steps)

                # Running mean of per-step wall time (includes validation).
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % FLAGS.print_every == 0:
                    # NOTE(review): `eval` presumably resolves to a metrics
                    # helper defined elsewhere in this module (the builtin eval
                    # does not accept these arguments) — confirm, and consider
                    # renaming it to stop shadowing the builtin.
                    train_acc, train_prec, train_rec, train_f1_score, train_conf_mat, fpr, tpr, thresholds = eval(
                        batch_labels, outs[-1])
                    print("Iter:", '%04d' % iter_num, "train_loss=",
                          "{:.5f}".format(train_cost), "train_accuracy=",
                          "{:.5f}".format(train_acc), "train_precision=",
                          "{:.5f}".format(train_prec), "train_recall=",
                          "{:.5f}".format(train_rec), "train_f1_score=",
                          "{:.5f}".format(train_f1_score))
                    with open(results_folder + "/train_stats.txt", "a") as fp:
                        fp.write(
                            "Iter:{:d} loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d}\n"
                            .format(iter_num, train_cost, train_acc, train_prec,
                                    train_rec, train_f1_score,
                                    train_conf_mat[0][0], train_conf_mat[0][1],
                                    train_conf_mat[1][0], train_conf_mat[1][1]))

                iter_num += 1
                total_steps += 1

                if total_steps > int(FLAGS.max_total_steps):
                    break

            if total_steps > int(FLAGS.max_total_steps):
                break

        print("Optimization Finished!")
        sess.run(val_adj_info.op)
        val_cost, val_acc, val_prec, val_rec, val_f1_score, val_conf_mat = incremental_evaluate(
            sess, model, minibatch, FLAGS.batch_size)
        print("Full validation stats:", "val_cost=", "{:.5f}".format(val_cost),
              "val_acc=", "{:.5f}".format(val_acc), "val_prec=",
              "{:.5f}".format(val_prec), "val_rec=", "{:.5f}".format(val_rec),
              "val_f1_score=", "{:.5f}".format(val_f1_score), "val_conf_mat=",
              val_conf_mat)
        with open(results_folder + "/val_stats.txt", "a") as fp:
            fp.write(
                "loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d} time=={:s}\n"
                .format(val_cost, val_acc, val_prec, val_rec, val_f1_score,
                        val_conf_mat[0][0], val_conf_mat[0][1],
                        val_conf_mat[1][0], val_conf_mat[1][1], current_time))
        print("Writing test set stats to file")
        test_cost, test_acc, test_prec, test_rec, test_f1_score, test_conf_mat = incremental_evaluate(
            sess, model, minibatch, FLAGS.batch_size, test=True)
        print("Full test stats:", "test_cost=", "{:.5f}".format(test_cost),
              "test_acc=", "{:.5f}".format(test_acc), "test_prec=",
              "{:.5f}".format(test_prec), "test_rec=", "{:.5f}".format(test_rec),
              "test_f1_score=", "{:.5f}".format(test_f1_score), "test_conf_mat=",
              test_conf_mat)
        with open(results_folder + "/test_stats.txt", "a") as fp:
            fp.write(
                "loss={:.5f} acc={:.5f} prec={:.5f} rec={:.5f} f1={:.5f} tp={:d} fp={:d} fn={:d} tn={:d} time=={:s}\n"
                .format(test_cost, test_acc, test_prec, test_rec, test_f1_score,
                        test_conf_mat[0][0], test_conf_mat[0][1],
                        test_conf_mat[1][0], test_conf_mat[1][1], current_time))
    finally:
        summary_writer.close()
        sess.close()
コード例 #3
0
    def inference(self, test_data, gpu_mem_fraction=None):
        """Restore the saved model from disk and predict the test nodes.

        Args:
            test_data: indexable bundle read as ``G = test_data[0]``,
                ``features = test_data[1]`` (may be None),
                ``id_map = test_data[2]`` and ``class_map = test_data[4]``.
            gpu_mem_fraction: not used.

        Returns:
            ``(test_nodes, test_preds)``. ``test_nodes`` lists the nodes of
            ``G`` whose ``'test'`` attribute is truthy, in ``G.nodes()``
            order. ``test_preds`` holds the matching rows of the stacked
            ``model.preds`` outputs.

        Raises:
            ValueError: if no checkpoint exists at
                ``<self._log_dir()>/model.ckpt``.
        """
        print("Inference.")
        timer = Timer()
        timer.tic()

        G = test_data[0]
        features = test_data[1]
        id_map = test_data[2]
        class_map = test_data[4]
        first_label = list(class_map.values())[0]
        if isinstance(first_label, list):
            # List-valued labels: the list length is taken as the class count.
            num_classes = len(first_label)
        else:
            # Scalar labels: count the distinct label values.
            num_classes = len(set(class_map.values()))

        if features is not None:
            # Pad with a dummy all-zero row appended after the real rows.
            features = np.vstack([features, np.zeros((features.shape[1], ))])

        placeholders = self._construct_placeholders(num_classes)
        minibatch = NodeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          class_map,
                                          num_classes,
                                          batch_size=self.batch_size,
                                          max_degree=self.max_degree)

        adj_info_ph = tf.compat.v1.placeholder(tf.int32,
                                               shape=minibatch.adj.shape)
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

        model = self._create_model(num_classes, placeholders, features,
                                   adj_info, minibatch)

        config = tf.compat.v1.ConfigProto(
            log_device_placement=self.log_device_placement)
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        # The session is closed in `finally`, also when restoring fails.
        sess = tf.compat.v1.Session(config=config)
        try:
            saver = tf.compat.v1.train.Saver(max_to_keep=self.epochs)

            sess.run(tf.compat.v1.global_variables_initializer(),
                     feed_dict={adj_info_ph: minibatch.adj})

            print("Restoring trained model.")
            checkpoint_file = os.path.join(self._log_dir(), "model.ckpt")
            # Check the checkpoint files themselves; the path string alone is
            # always truthy and says nothing about whether a model was saved.
            if not tf.compat.v1.train.checkpoint_exists(checkpoint_file):
                print("This model checkpoint does not exist. The model might " +
                      "not be trained yet or the checkpoint is invalid.")
                raise ValueError("No model checkpoint found at " +
                                 checkpoint_file)
            saver.restore(sess, checkpoint_file)
            print("Model restored.")

            # Predict using the test adjacency.
            val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)
            sess.run(val_adj_info.op)

            print("Computing predictions...")
            batch_preds = []
            nodes = []
            iter_num = 0
            finished = False
            while not finished:
                feed_dict_val, _, finished, nodes_subset = minibatch.incremental_node_val_feed_dict(
                    self.batch_size, iter_num, test=True)
                batch_preds.append(sess.run(model.preds,
                                            feed_dict=feed_dict_val))
                nodes.extend(nodes_subset)
                iter_num += 1
            all_preds = np.vstack(batch_preds)
            print("Computed.")

            # Map each evaluated node to its prediction row (last occurrence
            # wins), then keep only the test nodes.
            row_of = {node: row for row, node in enumerate(nodes)}
            test_nodes = [n for n in G.nodes() if G.node[n]['test']]
            test_preds = all_preds[[row_of[n] for n in test_nodes]]
            timer.toc()
        finally:
            sess.close()
        return test_nodes, test_preds
コード例 #4
0
ファイル: utils.py プロジェクト: zhh0998/CS-GNN
def load_data(prefix,
              num_layers=1,
              batch_size=1,
              concat=True,
              sample_number=50,
              directed=False):
    """Load a graph dataset and build subgraph minibatches for train/val/test.

    Reads:
        * the graph, via ``loadG(prefix, directed)``;
        * ``<prefix>-class_map.json``, via ``loadjson``;
        * ``<prefix>-feats.npy`` and ``<prefix>-feats_t.npy``, when present.

    Node splitting:
        Nodes are split by their ``'test'`` / ``'val'`` attributes. Nodes with
        both flags set count as unlabeled; if any exist, they are removed with
        ``rm_useless`` and the split is recomputed.

    Feature scaling:
        Available feature matrices are standardised with a ``StandardScaler``
        fitted on the training nodes only.

    Returns:
        A 16-tuple: ``G``, then for each of train / val / test the
        ``(subgraphs, subfeats, subfeats_t, sublabels, submasks)`` produced by
        ``NodeMinibatchIterator(...).get_subgraphs()``.

    Raises:
        Exception: if the class-map size disagrees with the node split, or if
            the largest node id is not ``G.number_of_nodes() - 1``.
    """

    def split_nodes(graph):
        """Partition nodes into (train, val, test, unlabeled) by 'test'/'val'."""
        train, val, test, unlabeled = [], [], [], []
        for n in graph.nodes():
            attrs = graph.node[n]
            if attrs['test']:
                (unlabeled if attrs['val'] else test).append(n)
            else:
                (val if attrs['val'] else train).append(n)
        return train, val, test, unlabeled

    with futures.ProcessPoolExecutor(max_workers=5) as executor:
        # 1. Read data: graph and class map load in worker processes while the
        # raw feature matrix (if any) is read here.
        futs = [
            executor.submit(loadG, prefix, directed),
            executor.submit(loadjson, prefix + '-class_map.json'),
        ]
        feats_path = prefix + '-feats.npy'
        feats = np.load(feats_path) if os.path.exists(feats_path) else None

        # 2. Label conversion: list labels are kept, scalar labels become int.
        class_map = futs[1].result()
        if isinstance(list(class_map.values())[0], list):
            lab_conversion = lambda n: n
        else:
            lab_conversion = int
        G = futs[0].result()

        # 3. Node-id conversion; graph processing runs in a worker meanwhile.
        if isinstance(G.nodes()[0], int):
            conversion = int
        else:
            conversion = lambda n: n
        fut = executor.submit(process_graph, G)
        class_map = convert_dict(class_map, conversion, lab_conversion)
        # Single-label map (first value is not a list): convert via convert_list.
        if class_map and type(next(iter(class_map.values()))) is not list:
            class_map = convert_list(class_map)
        G = fut.result()

        # 4. Division into train / val / test / unlabeled.
        train_nodes, val_nodes, test_nodes, unlabeled_nodes = split_nodes(G)
        class_map = convert_ndarray(class_map)
        # Remove unlabeled nodes (see rm_useless) and recompute the split.
        if len(unlabeled_nodes) > 0:
            G, feats, class_map = rm_useless(G, feats, class_map,
                                             unlabeled_nodes, num_layers)
            train_nodes, val_nodes, test_nodes, unlabeled_nodes = split_nodes(G)
        # Sanity checks on node ids.
        if len(class_map) != len(train_nodes) + len(val_nodes) + len(
                test_nodes) + len(unlabeled_nodes):
            raise Exception('Error: repeat node id!')
        if max(G.nodes()) != G.number_of_nodes() - 1:
            raise Exception('Error: node id out of range!')

        # 5. Topology features: load if precomputed; otherwise generate_tf is
        # passed on (presumably so NodeMinibatchIterator builds them — confirm).
        start_time = time.time()
        feats_t_path = prefix + '-feats_t.npy'
        if os.path.exists(feats_t_path):
            feats_t = np.load(feats_t_path)
            generate_tf = False
        else:
            feats_t = None
            generate_tf = True

        # 6. Standardise features with statistics from training nodes only.
        # Raw features are optional (step 1), so scaling is skipped without them.
        train_ids = np.array(train_nodes)
        if feats is not None:
            scaler = StandardScaler()
            scaler.fit(feats[train_ids])
            feats = scaler.transform(feats)
        if not generate_tf:
            scaler_t = StandardScaler()
            scaler_t.fit(feats_t[train_ids])
            feats_t = scaler_t.transform(feats_t)
        print("load data in", "{:.5f}".format(time.time() - start_time),
              "seconds")

        # 7. Minibatches (self-loops removed first).
        print('start minibatch for train, val, test ...')
        start_time = time.time()
        G.remove_edges_from(G.selfloop_edges())
        train_subgraphs, train_subfeats, train_subfeats_t, train_sublabels, train_submasks = NodeMinibatchIterator(
            G, num_layers, batch_size, train_nodes, feats, feats_t, class_map,
            concat, generate_tf, prefix, sample_number).get_subgraphs()
        val_subgraphs, val_subfeats, val_subfeats_t, val_sublabels, val_submasks = NodeMinibatchIterator(
            G, num_layers, batch_size, val_nodes, feats, feats_t, class_map,
            concat, generate_tf, prefix, sample_number).get_subgraphs()
        test_subgraphs, test_subfeats, test_subfeats_t, test_sublabels, test_submasks = NodeMinibatchIterator(
            G, num_layers, batch_size, test_nodes, feats, feats_t, class_map,
            concat, generate_tf, prefix, sample_number).get_subgraphs()
        print('Done with minibatch within {:.5f} seconds, start training...'.
              format(time.time() - start_time))

    return G, train_subgraphs, train_subfeats, train_subfeats_t, train_sublabels, train_submasks, val_subgraphs, val_subfeats, \
        val_subfeats_t, val_sublabels, val_submasks, test_subgraphs, test_subfeats, test_subfeats_t, test_sublabels, test_submasks
コード例 #5
0
    def train(self, train_data, test_data=None):
        """Train the supervised model, checkpointing the best epoch by val loss.

        Args:
            train_data: indexable bundle read as ``G = train_data[0]``,
                ``features = train_data[1]`` (may be None),
                ``id_map = train_data[2]`` and ``class_map = train_data[4]``.
            test_data: not read by the training loop.

        Behaviour:
            * Runs up to ``self.epochs`` epochs, stopping once ``total_steps``
              exceeds ``self.max_total_steps``.
            * Validates every ``self.validate_iter`` iterations of an epoch.
            * After each epoch, saves ``<self._log_dir()>/model.ckpt`` only
              when that epoch's mean validation loss is the lowest so far.
            * Passes the per-epoch mean losses to ``_plot_losses`` and
              ``_print_stats`` and writes ``val_stats.txt`` under
              ``self._log_dir()``.
        """
        print("Training model...")
        timer = Timer()
        timer.tic()

        G = train_data[0]
        features = train_data[1]
        id_map = train_data[2]
        class_map = train_data[4]
        first_label = list(class_map.values())[0]
        if isinstance(first_label, list):
            # List-valued labels: the list length is taken as the class count.
            num_classes = len(first_label)
        else:
            # Scalar labels: count the distinct label values.
            num_classes = len(set(class_map.values()))

        if features is not None:
            # Pad with a dummy all-zero row appended after the real rows.
            features = np.vstack([features, np.zeros((features.shape[1], ))])

        placeholders = self._construct_placeholders(num_classes)
        minibatch = NodeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          class_map,
                                          num_classes,
                                          batch_size=self.batch_size,
                                          max_degree=self.max_degree)

        adj_info_ph = tf.compat.v1.placeholder(tf.int32,
                                               shape=minibatch.adj.shape)
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

        model = self._create_model(num_classes, placeholders, features,
                                   adj_info, minibatch)

        config = tf.compat.v1.ConfigProto(
            log_device_placement=self.log_device_placement)
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        # Initialize session
        sess = tf.compat.v1.Session(config=config)
        merged = tf.compat.v1.summary.merge_all()

        # Initialize model saver
        saver = tf.compat.v1.train.Saver(max_to_keep=self.epochs)

        # Init variables
        sess.run(tf.compat.v1.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})

        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        # Per-epoch mean losses.
        train_losses = []
        validation_losses = []

        # Ops that swap the adjacency variable between minibatch.adj (training)
        # and minibatch.test_adj (used for every evaluation).
        train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
        val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)

        for epoch in range(self.epochs):
            minibatch.shuffle()

            iter_num = 0
            print('Epoch: %04d' % (epoch))
            epoch_val_costs.append(0)
            train_loss_epoch = []
            validation_loss_epoch = []
            while not minibatch.end():
                feed_dict, labels = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: self.dropout})

                t = time.time()
                # Training step.
                outs = sess.run(
                    [merged, model.opt_op, model.loss, model.preds],
                    feed_dict=feed_dict)
                train_cost = outs[2]
                train_loss_epoch.append(train_cost)

                # iter_num restarts at 0 every epoch, so validation runs on each
                # epoch's first step and val_* are set before being printed.
                if iter_num % self.validate_iter == 0:
                    sess.run(val_adj_info.op)
                    if self.validate_batch_size == -1:
                        val_cost, val_f1_mic, val_f1_mac, duration = self._incremental_evaluate(
                            sess, model, minibatch, self.batch_size)
                    else:
                        val_cost, val_f1_mic, val_f1_mac, duration = self._evaluate(
                            sess, model, minibatch, self.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost
                    validation_loss_epoch.append(val_cost)

                # Running mean of per-step wall time (includes validation).
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % self.print_every == 0:
                    train_f1_mic, train_f1_mac = self._calc_f1(
                        labels, outs[-1])
                    print("Iter:", '%04d' % iter_num, "train_loss=",
                          "{:.5f}".format(train_cost), "train_f1_mic=",
                          "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                          "{:.5f}".format(train_f1_mac), "val_loss=",
                          "{:.5f}".format(val_cost), "val_f1_mic=",
                          "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                          "{:.5f}".format(val_f1_mac), "time=",
                          "{:.5f}".format(avg_time))

                iter_num += 1
                total_steps += 1

                if total_steps > self.max_total_steps:
                    break

            # An epoch with no minibatches has no losses to average; skip its
            # bookkeeping instead of dividing by zero. Any step that ran also
            # validated (iter_num 0), so validation_loss_epoch is non-empty too.
            if train_loss_epoch:
                train_losses.append(
                    sum(train_loss_epoch) / len(train_loss_epoch))
                validation_losses.append(
                    sum(validation_loss_epoch) / len(validation_loss_epoch))

                # Checkpoint only when this epoch's mean validation loss is
                # the lowest seen so far (not on every epoch).
                if validation_losses[-1] == min(validation_losses):
                    print("Minimum validation loss so far ({}) at epoch {}.".
                          format(validation_losses[-1], epoch))
                    print("Saving model at epoch {}.".format(epoch))
                    saver.save(sess,
                               os.path.join(self._log_dir(), "model.ckpt"))

            if total_steps > self.max_total_steps:
                break

        print("Optimization Finished!")

        training_time = timer.toc()
        self._plot_losses(train_losses, validation_losses)
        self._print_stats(train_losses, validation_losses, training_time)

        sess.run(val_adj_info.op)
        val_cost, val_f1_mic, val_f1_mac, duration = self._incremental_evaluate(
            sess, model, minibatch, self.batch_size)
        print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
              "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
              "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
        with open(self._log_dir() + "val_stats.txt", "w") as fp:
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".
                     format(val_cost, val_f1_mic, val_f1_mac, duration))