Example #1
def getSampledBatchIterator(G, id_map, class_map, num_classes, batch_size):
    placeholders = {
        'labels':
        tf.placeholder(tf.float32, shape=(None, num_classes), name='labels'),
        'batch':
        tf.placeholder(tf.int32, shape=(None), name='batch1'),
        'dropout':
        tf.placeholder_with_default(0., shape=(), name='dropout'),
        'batch_size':
        tf.placeholder(tf.int32, name='batch_size'),
        'hop1':
        tf.placeholder(tf.int32, shape=(None), name='hop1'),
        'hop2':
        tf.placeholder(tf.int32, shape=(None), name='hop2'),
    }

    layer_infos_top_down = [
        SAGEInfo("node", None, 25, 128),
        SAGEInfo("node", None, 10, 128)
    ]

    minibatch = NodeMinibatchIteratorWithKHop(
        G,
        id_map,
        placeholders,
        class_map,
        num_classes,
        layer_infos_top_down,
        batch_size=batch_size,
        max_degree=128,
    )
    return minibatch
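A minimal usage sketch for the iterator above (assumptions: NodeMinibatchIteratorWithKHop exposes the same end()/next_minibatch_feed_dict() interface as the stock GraphSAGE NodeMinibatchIterator, G/id_map/class_map come from a load_data()-style loader, and the concrete numbers are hypothetical):

num_classes = 121  # hypothetical label count
minibatch = getSampledBatchIterator(G, id_map, class_map, num_classes, batch_size=512)
while not minibatch.end():
    feed_dict, labels = minibatch.next_minibatch_feed_dict()
    # feed_dict fills the 'batch', 'batch_size' and 'labels' placeholders (and,
    # presumably, the sampled 'hop1'/'hop2' neighbor ids) for one step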
Example #2
def train(train_data, test_data=None):

    G = train_data[0]  # [z]: networkx.Graph
    features = train_data[1]  # [z]: |V|xD
    id_map = train_data[2]  # [z]: maps graph node ids to consecutive integer indices
    class_map = train_data[4]

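    # class_map values are either multi-hot label lists (multi-label case) or scalar class ids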
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    # [z]: context_pairs is a list of co-occurring node pairs from random walks (only used when FLAGS.random_context is set)
    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    # [z]: minibatch.adj is an adjacency table with up to max_degree uniformly sampled neighbors per node
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    # [z]: adj_info_ph is of R^{|V|xFLAGS.max_degree}
    # [z]: minibatch.adj is the sampled neighbor table, also R^{|V|xFLAGS.max_degree}
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            # [z]: SAGEInfo: [layer_name, neigh_sampler, num_samples, output_dim]
            # [z]: NOTE: it may be simpler to start from a single-layer model, i.e. FLAGS.samples_2 = 0
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(
            num_classes,
            placeholders,
            features,  # [z]: |V|xD
            adj_info,
            minibatch.deg,
            layer_infos,
            model_size=FLAGS.model_size,  # [z]: either "small" or "big"
            sigmoid_loss=FLAGS.sigmoid,
            identity_dim=FLAGS.identity_dim,
            logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})
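    # adj_info_ph is only fed here: it seeds the non-trainable adj_info Variable, which
    # is later switched between the train and test adjacency tables via tf.assign below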

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    # [z]: adj_info is first set to minibatch.adj, the adjacency table built from training edges only
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    # [z]: minibatch.test_adj has the same shape but is built from the full graph (used for val/test)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            # [z]: this actually evaluates the graph built in SupervisedGraphsage.build()
            # [z]: feed_dict maps the tf.placeholders to this minibatch's values
            # [z]: opt_op applies the gradients to the parameters; its return value is not used
            # [z]: model.preds is R^{batch_size x num_classes} (e.g. 512x121)
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            #for k in z.debug_vars.keys():
            #    print('-------------- {} --------------'.format(k))
            #    dbg = sess.run(z.debug_vars[k], feed_dict=feed_dict)
            #    import pdb; pdb.set_trace()
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_f1_mic=",
                      "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                      "{:.5f}".format(train_f1_mac), "val_loss=",
                      "{:.5f}".format(val_cost), "val_f1_mic=",
                      "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                      "{:.5f}".format(val_f1_mac), "time=",
                      "{:.5f}".format(avg_time))
            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
          "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write(
            "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            val_cost, val_f1_mic, val_f1_mac))
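For reference, a minimal sketch of what a calc_f1 helper like the one called in the loop above typically computes; this is an assumption based on its usage here (sigmoid multi-label outputs thresholded at 0.5, softmax outputs reduced with argmax), not necessarily this repository's exact code:

import numpy as np
from sklearn import metrics

def calc_f1(y_true, y_pred, sigmoid=True):
    # binarize multi-label sigmoid scores, or pick the argmax class for softmax
    if sigmoid:
        y_pred = (y_pred > 0.5).astype(int)
    else:
        y_true = np.argmax(y_true, axis=1)
        y_pred = np.argmax(y_pred, axis=1)
    return (metrics.f1_score(y_true, y_pred, average="micro"),
            metrics.f1_score(y_true, y_pred, average="macro"))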
Example #3
def train(train_data,
          minibatch,
          model_name="graphsage_mean",
          profile=False,
          test_data=None):
    G = train_data[0]
    # features = train_data[1]
    id_map = train_data[1]
    '''
        if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])
    '''

    context_pairs = train_data[2] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch.set_place_holder(placeholders)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if model_name == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)
    elif model_name == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)

    elif model_name == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)

    elif model_name == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)
    elif model_name == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   fea_dim=FLAGS.feats_dim,
                                   logging=True)

    elif model_name == 'n2v':
        model = Node2VecModel(
            placeholders,
            602,
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # profile init
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    many_runs_timeline = TimeLiner()

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            if profile:
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1
                ],
                                feed_dict=feed_dict,
                                options=options,
                                run_metadata=run_metadata)
                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                if total_steps >= 50:
                    many_runs_timeline.update_timeline(chrome_trace)
            else:
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1
                ],
                                feed_dict=feed_dict)
            tf.train.write_graph(sess.graph_def,
                                 './export_models',
                                 'graphsage_{0}.pb'.format(model_name),
                                 as_text=False)
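            # NOTE: the break below exits after the first training step, so this example
            # only exports the GraphDef; the metric/validation code further down never runs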
            break
            train_cost = outs[2]
            train_mrr = outs[5]

            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  #
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            if total_steps >= 50:
                avg_time = (avg_time * (total_steps - 50) + time.time() -
                            t) / (total_steps - 50 + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss=",
                    "{:.5f}".format(val_cost),
                    "val_mrr=",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema=",
                    "{:.5f}".format(shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    if profile:
        # write profile data into json format
        file_name = model_name + ".json"
        many_runs_timeline.save(file_name)
        file_name = model_name + "_profile_time.txt"
        with open(file_name, "w") as f:
            f.write("time= {:.5f}s".format(avg_time))
    else:
        file_name = model_name + "_no_profile_time.txt"
        with open(file_name, "w") as f:
            f.write("time= {:.5f}s".format(avg_time))
    print("Optimization Finished!")
Example #4
def train(train_data, test_data=None):
    G = train_data[0]  # G is a NetworkX graph object; all of these entries come from load_data()
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    class_map2 = train_data[5]
    class_map3 = train_data[6]
    #class_map = class_map
    hierarchy = FLAGS.hierarchy
    ko_threshold = FLAGS.ko_threshold
    ko_threshold2 = FLAGS.ko_threshold2
    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])
    features = tf.cast(features, tf.float32)
    for hi_num in range(hierarchy):
        if hi_num == 0:
            class_map_ko_0 = construct_class_numpy(class_map)
            class_map_ko = construct_class_numpy(class_map)
            a = class_map_ko.sum(axis=0)

            b = np.sort(a)
            c = b.tolist()
            plt.figure()
            plt.plot(c)
            plt.legend(loc=0)
            plt.xlabel('KO index')
            plt.ylabel('Number')
            plt.grid(True)
            plt.axis('tight')
            plt.savefig("./graph/imbalance.png")
            plt.show()

            count = 0
            list_del = []
            for i in a:
                if i < ko_threshold:
                    list_del.append(count)
                    count += 1
                else:
                    count += 1
            class_map_ko = np.delete(class_map_ko, list_del, axis=1)
            count = 0
            for key in class_map:
                arr = class_map_ko[count, :]
                class_map[key] = arr.tolist()
                count += 1
            num_classes = class_map_ko.shape[1]

        elif hi_num == 1:
            class_map = class_map2
            class_map_ko_1 = construct_class_numpy(class_map)
            class_map_ko = construct_class_numpy(class_map)
            a = class_map_ko.sum(axis=0)
            count = 0
            list_del = []
            for i in a:
                if i >= ko_threshold or i <= ko_threshold2:
                    list_del.append(count)
                    count += 1
                else:
                    count += 1
            class_map_ko = np.delete(class_map_ko, list_del, axis=1)
            count = 0
            for key in class_map:
                arr = class_map_ko[count, :]
                class_map[key] = arr.tolist()
                count += 1
            num_classes = class_map_ko.shape[1]

        elif hi_num == 2:
            class_map = class_map3
            class_map_ko_2 = construct_class_numpy(class_map)
            class_map_ko = construct_class_numpy(class_map)
            a = class_map_ko.sum(axis=0)
            count = 0
            list_del = []
            for i in a:
                if i > ko_threshold2:
                    list_del.append(count)
                    count += 1
                else:
                    count += 1
            class_map_ko = np.delete(class_map_ko, list_del, axis=1)
            count = 0
            for key in class_map:
                arr = class_map_ko[count, :]
                class_map[key] = arr.tolist()
                count += 1
            num_classes = class_map_ko.shape[1]

        #if hi_num == 2:
        #class_map_ko = construct_class_numpy(class_map)

        OTU_ko_num = class_map_ko.sum(axis=1)
        count = 0
        for num in OTU_ko_num:
            if num < 100:
                count += 1
        ko_cb = construct_class_para(class_map_ko, 0, FLAGS.beta1)
        ko_cb = tf.cast(ko_cb, tf.float32)
        f1_par = construct_class_para(class_map_ko, 1, FLAGS.beta2)

        context_pairs = train_data[3] if FLAGS.random_context else None
        placeholders = construct_placeholders(num_classes)
        minibatch = NodeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          class_map,
                                          num_classes,
                                          batch_size=FLAGS.batch_size,
                                          max_degree=FLAGS.max_degree,
                                          context_pairs=context_pairs)

        with open('test_nodes.txt', 'w') as f:
            json.dump(minibatch.test_nodes, f)
    ###########
        list_node = minibatch.nodes
        for otu in minibatch.train_nodes:
            if otu in list_node:
                list_node.remove(otu)
        for otu in minibatch.val_nodes:
            if otu in list_node:
                list_node.remove(otu)
        for otu in minibatch.test_nodes:
            if otu in list_node:
                list_node.remove(otu)
    ###########
        if hi_num == 0:
            adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
        # adj_info is a Variable (rather than a constant) because its value is swapped
        # between training and testing, which is done with tf.assign().
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

        if FLAGS.model == 'graphsage_mean':
            # Create model
            sampler = UniformNeighborSampler(adj_info)

            if FLAGS.samples_3 != 0:
                layer_infos = [
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                    SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
                ]

            elif FLAGS.samples_2 != 0:
                layer_infos = [
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
                ]

            else:
                layer_infos = [
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
                ]

            model = SupervisedGraphsage(
                num_classes,
                placeholders,
                features,
                adj_info,
                minibatch.deg,  # degree of each node
                layer_infos,
                ko_cb,
                hi_num,
                model_size=FLAGS.model_size,
                sigmoid_loss=FLAGS.sigmoid,
                identity_dim=FLAGS.identity_dim,
                logging=True,
                concat=False)

        elif FLAGS.model == 'gcn':
            # Create model
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        aggregator_type="gcn",
                                        model_size=FLAGS.model_size,
                                        concat=False,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True)

        elif FLAGS.model == 'graphsage_seq':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        aggregator_type="seq",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True,
                                        concat=True)

        elif FLAGS.model == 'graphsage_maxpool':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        aggregator_type="maxpool",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True,
                                        concat=True)
        elif FLAGS.model == 'mlp':
            # Create model
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos,
                                        ko_cb,
                                        hi_num,
                                        aggregator_type="mlp",
                                        model_size=FLAGS.model_size,
                                        concat=False,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True)

        elif FLAGS.model == 'graphsage_meanpool':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        ko_cb,
                                        hi_num,
                                        layer_infos=layer_infos,
                                        aggregator_type="meanpool",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True,
                                        concat=True)
        elif FLAGS.model == 'gat':
            sampler = UniformNeighborSampler(adj_info)
            # build a two-layer network: (neighbor sampler, number of sampled neighbors, output dim) per layer
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(
                num_classes,
                placeholders,
                features,
                adj_info,
                minibatch.deg,
                concat=True,
                layer_infos=layer_infos,
                aggregator_type="gat",
                model_size=FLAGS.model_size,
                sigmoid_loss=FLAGS.sigmoid,
                identity_dim=FLAGS.identity_dim,
                logging=True,
            )
        else:
            raise Exception('Error: model name unrecognized.')

        config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
        config.allow_soft_placement = True

        # Initialize session
        sess = tf.Session(config=config)
        # sess = tf_dbg.LocalCLIDebugWrapperSession(sess)
        #merged = tf.summary.merge_all()  # merge all summaries (used for visualization)
        #summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)  # log info for viewing in TensorBoard

        # Init variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})
        #sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph2: minibatch2.adj})

        # Train model
        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []
        epoch_val_costs2 = []
        # minibatch.adj and minibatch.test_adj have the same shape; adj simply masks out
        # the entries that do not belong to the training set.
        # "val" here stands for validation.
        train_adj_info = tf.assign(
            adj_info, minibatch.adj
        )  # tf.assign() assigns a value to a tf.Variable and returns the assigned tensor
        val_adj_info = tf.assign(
            adj_info,
            minibatch.test_adj)  # assign() is an operation; it only runs inside sess.run()
        it = 0
        train_loss = []
        val_loss = []
        train_f1_mics = []
        val_f1_mics = []
        loss_plt = []
        loss_plt2 = []
        trainf1mi = []
        trainf1ma = []
        valf1mi = []
        valf1ma = []
        iter_num = 0

        for epoch in range(FLAGS.epochs * 2):
            if epoch < FLAGS.epochs:
                minibatch.shuffle()
                iter = 0
                print('Epoch: %04d' % (epoch + 1))
                epoch_val_costs.append(0)
                while not minibatch.end():
                    # Construct feed dictionary
                    # each minibatch selects its nodes by changing feed_dict
                    feed_dict, labels = minibatch.next_minibatch_feed_dict(
                    )  # feed_dict holds the placeholder values filled in by the minibatch iterator
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    t = time.time()
                    # Training step
                    outs = sess.run([model.opt_op, model.loss, model.preds],
                                    feed_dict=feed_dict)
                    train_cost = outs[1]
                    iter_num = iter_num + 1
                    loss_plt.append(float(train_cost))
                    if iter % FLAGS.print_every == 0:
                        # Validation
                        sess.run(val_adj_info.op
                                 )  # the fetch argument is an operation; sess.run() executes it
                        if FLAGS.validate_batch_size == -1:
                            val_cost, val_f1_mic, val_f1_mac, duration, otu_lazy, _, val_preds, __, val_accuracy, val_mi_roc_auc = incremental_evaluate(
                                sess, model, minibatch, f1_par,
                                FLAGS.batch_size)
                        else:
                            val_cost, val_f1_mic, val_f1_mac, duration, val_accuracy, val_mi_roc_auc = evaluate(
                                sess, model, minibatch, f1_par,
                                FLAGS.validate_batch_size)
                        sess.run(train_adj_info.op
                                 )  # every tensor has an .op attribute: the operation that produced it
                        epoch_val_costs[-1] += val_cost

                    #if iter % FLAGS.print_every == 0:
                    #summary_writer.add_summary(outs[0], total_steps)

                    # Print results
                    avg_time = (avg_time * total_steps + time.time() -
                                t) / (total_steps + 1)
                    loss_plt2.append(float(val_cost))
                    valf1mi.append(float(val_f1_mic))
                    valf1ma.append(float(val_f1_mac))

                    if iter % FLAGS.print_every == 0:
                        train_f1_mic, train_f1_mac, train_f1_none, train_accuracy, train_mi_roc_auc = calc_f1(
                            labels, outs[-1], f1_par)
                        trainf1mi.append(float(train_f1_mic))
                        trainf1ma.append(float(train_f1_mac))

                        print(
                            "Iter:",
                            '%04d' % iter,
                            # training-set loss and metrics
                            "train_loss=",
                            "{:.5f}".format(train_cost),
                            "train_f1_mic=",
                            "{:.5f}".format(train_f1_mic),
                            "train_f1_mac=",
                            "{:.5f}".format(train_f1_mac),
                            "train_accuracy=",
                            "{:.5f}".format(train_accuracy),
                            "train_ra_mi=",
                            "{:.5f}".format(train_mi_roc_auc),

                            # validation-set loss and metrics
                            "val_loss=",
                            "{:.5f}".format(val_cost),
                            "val_f1_mic=",
                            "{:.5f}".format(val_f1_mic),
                            "val_f1_mac=",
                            "{:.5f}".format(val_f1_mac),
                            "val_accuracy=",
                            "{:.5f}".format(val_accuracy),
                            "val_ra_mi=",
                            "{:.5f}".format(val_mi_roc_auc),
                            "time=",
                            "{:.5f}".format(avg_time))
                        train_loss.append(train_cost)
                        val_loss.append(val_cost)
                        train_f1_mics.append(train_f1_mic)
                        val_f1_mics.append(val_f1_mic)

                    iter += 1
                    total_steps += 1

                    if total_steps > FLAGS.max_total_steps:
                        break

                if total_steps > FLAGS.max_total_steps:
                    break
    ###################################################################################################################
    # begin second degree training
    ###################################################################################################################
            """""
            else:
                minibatch2.shuffle()
                iter = 0
                print('Epoch2: %04d' % (epoch + 1))
                epoch_val_costs2.append(0)
                while not minibatch2.end():
                # Construct feed dictionary
                # each minibatch selects its nodes by changing feed_dict
                    feed_dict, labels = minibatch2.next_minibatch_feed_dict()  # feed_dict holds the placeholder values filled in by the minibatch iterator
                    feed_dict.update({placeholders2['dropout']: FLAGS.dropout})
    
                    t = time.time()
                # Training step
                    #global model2
                    outs = sess.run([merged, model2.opt_op, model2.loss, model2.preds], feed_dict=feed_dict)
    
                    train_cost = outs[2]
                    iter_num = iter_num + 1
                    loss_plt.append(float(train_cost))
                    if iter % FLAGS.print_every == 0:
                    # Validation
                        sess.run(val_adj_info2.op)  # the fetch argument is an operation; sess.run() executes it
                        if FLAGS.validate_batch_size == -1:
                            val_cost, val_f1_mic, val_f1_mac, duration, otu_lazy = incremental_evaluate(sess, model2, minibatch2,
                                                                                                    FLAGS.batch_size)
                        else:
                            val_cost, val_f1_mic, val_f1_mac, duration = evaluate(sess, model2, minibatch2,
                                                                              FLAGS.validate_batch_size)
                        sess.run(train_adj_info2.op)  # every tensor has an .op attribute: the operation that produced it
                        epoch_val_costs2[-1] += val_cost
    
                    if iter % FLAGS.print_every == 0:
                        summary_writer.add_summary(outs[0], total_steps)
    
                # Print results
                    avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)
                    loss_plt2.append(float(val_cost))
                    valf1mi.append(float(val_f1_mic))
                    valf1ma.append(float(val_f1_mac))
    
                    if iter % FLAGS.print_every == 0:
                        train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                        trainf1mi.append(float(train_f1_mic))
                        trainf1ma.append(float(train_f1_mac))
    
                        print("Iter:", '%04d' % iter,
                              # training-set loss and metrics
                              "train_loss=", "{:.5f}".format(train_cost),
                              "train_f1_mic=", "{:.5f}".format(train_f1_mic),
                              "train_f1_mac=", "{:.5f}".format(train_f1_mac),
                              # validation-set loss and metrics
                              "val_loss=", "{:.5f}".format(val_cost),
                              "val_f1_mic=", "{:.5f}".format(val_f1_mic),
                              "val_f1_mac=", "{:.5f}".format(val_f1_mac),
                              "time=", "{:.5f}".format(avg_time))
                        train_loss.append(train_cost)
                        val_loss.append(val_cost)
                        train_f1_mics.append(train_f1_mic)
                        val_f1_mics.append(val_f1_mic)
    
                    iter += 1
                    total_steps += 1
    
                    if total_steps > FLAGS.max_total_steps:
                        break
    
                if total_steps > FLAGS.max_total_steps:
                    break
            """

        print("Optimization Finished!")
        sess.run(val_adj_info.op)

        val_cost, val_f1_mic, val_f1_mac, duration, otu_f1, ko_none, test_preds, test_labels, test_accuracy, test_mi_roc_auc = incremental_evaluate(
            sess, model, minibatch, f1_par, FLAGS.batch_size, test=True)
        print(
            "Full validation stats:",
            "loss=",
            "{:.5f}".format(val_cost),
            "f1_micro=",
            "{:.5f}".format(val_f1_mic),
            "f1_macro=",
            "{:.5f}".format(val_f1_mac),
            "accuracy=",
            "{:.5f}".format(test_accuracy),
            "roc_auc_mi=",
            "{:.5f}".format(test_mi_roc_auc),
            "time=",
            "{:.5f}".format(duration),
        )
        if hi_num == 0:
            last_train_f1mi = trainf1mi
            last_train_f1ma = trainf1ma
            last_train_loss = loss_plt
            final_preds = test_preds
            final_labels = test_labels
        else:
            final_preds = np.hstack((final_preds, test_preds))
            final_labels = np.hstack((final_labels, test_labels))

        if hi_num == hierarchy - 1:
            # update test preds
            """
            ab_ko = json.load(open(FLAGS.train_prefix + "-below1500_ko_idx.json"))
            #ab_ko = construct_class_numpy(ab_ko)
            f1_par = construct_class_para(class_map_ko_0, 1, FLAGS.beta2)
            i = 0
            for col in ab_ko:
                last_preds[..., col] = test_preds[..., i]
                i += 1
            f1_scores = calc_f1(last_preds, last_labels, f1_par)
            """
            #pdb.set_trace()
            f1_par = construct_class_para(class_map_ko_0, 1, FLAGS.beta2)
            #final_preds = np.hstack((last_preds, test_preds))
            #final_labels = np.hstack((last_labels, test_labels))
            f1_scores = calc_f1(final_preds, final_labels, f1_par)
            print('\n', 'Hierarchy combination f1 score:')
            print("f1_micro=", "{:.5f}".format(f1_scores[0]), "f1_macro=",
                  "{:.5f}".format(f1_scores[1]), "accuracy=",
                  "{:.5f}".format(f1_scores[3]), "roc_auc_mi=",
                  "{:.5f}".format(f1_scores[4]))

        pred = y_ture_pre(sess, model, minibatch, FLAGS.batch_size)
        for i in range(pred.shape[0]):
            sum = 0
            for l in range(pred.shape[1]):
                sum = sum + pred[i, l]
            for m in range(pred.shape[1]):
                pred[i, m] = pred[i, m] / sum
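        # equivalent vectorized form: pred = pred / pred.sum(axis=1, keepdims=True)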
        id = json.load(open(FLAGS.train_prefix + "-id_map.json"))
        # x_train = np.empty([pred.shape[0], array.s)
        num = 0
        session = tf.Session()
        array = session.run(features)
        x_test = np.empty([pred.shape[0], array.shape[1]])
        x_train = np.empty([len(minibatch.train_nodes), array.shape[1]])
        for node in minibatch.val_nodes:
            x_test[num] = array[id[node]]
            num = num + 1
        num1 = 0
        for node in minibatch.train_nodes:
            x_train[num1] = array[id[node]]
            num1 = num1 + 1

        with open(log_dir() + "val_stats.txt", "w") as fp:
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".
                     format(val_cost, val_f1_mic, val_f1_mac, duration))

        print("Writing test set stats to file (don't peak!)")
        val_cost, val_f1_mic, val_f1_mac, duration, otu_lazy, ko_none, _, __, test_accuracy, test_mi_roc_auc = incremental_evaluate(
            sess, model, minibatch, f1_par, FLAGS.batch_size, test=True)
        with open(log_dir() + "test_stats.txt", "w") as fp:
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac))

        incremental_evaluate_for_each(sess,
                                      model,
                                      minibatch,
                                      FLAGS.batch_size,
                                      test=True)


##################################################################################################################
# plot loss
    plt.figure()
    plt.plot(loss_plt, label='train_loss')
    plt.plot(loss_plt2, label='val_loss')
    plt.legend(loc=0)
    plt.xlabel('Iteration')
    plt.ylabel('loss')
    plt.title('Loss plot')
    plt.grid(True)
    plt.axis('tight')
    #plt.savefig("./graph/HOPE_same_loss.png")
    plt.show()

    # plot loss1+loss2
    plt.figure()
    plt.plot(last_train_loss, label='train_loss1')
    plt.plot(loss_plt, label='train_loss2')
    plt.legend(loc=0)
    plt.xlabel('Iteration')
    plt.ylabel('loss')
    plt.title('Loss plot')
    plt.grid(True)
    plt.axis('tight')
    #plt.savefig("./graph/HOPE_same_loss1+2.png")
    plt.show()

    # plot f1 score
    plt.figure()
    plt.subplot(211)
    plt.plot(trainf1mi, label='train_f1_micro')
    plt.plot(valf1mi, label='val_f1_micro')
    plt.legend(loc=0)
    plt.xlabel('Iterations')
    plt.ylabel('f1_micro')
    plt.title('train_val_f1_score')
    plt.grid(True)
    plt.axis('tight')

    plt.subplot(212)
    plt.plot(trainf1ma, label='train_f1_macro')
    plt.plot(valf1ma, label='val_f1_macro')
    plt.legend(loc=0)
    plt.xlabel('Iteration')
    plt.ylabel('f1_macro')
    plt.grid(True)
    plt.axis('tight')
    #plt.savefig("./graph/HOPE_same_f1.png")
    plt.show()

    # plot f1 score1+2
    plt.figure()
    plt.plot(last_train_f1mi, label='train_f1_micro1')
    plt.plot(last_train_f1ma, label='train_f1_macro1')
    plt.plot(trainf1mi, label='train_f1_micro2')
    plt.plot(trainf1ma, label='train_f1_macro2')
    plt.legend(loc=0)
    plt.xlabel('Iterations')
    plt.ylabel('f1_micro')
    plt.title('train_f1_micro_score')
    plt.grid(True)
    plt.axis('tight')
    # plt.savefig("./graph/HOPE_same_f1_1+2.png")
    plt.show()

    # f1
    plt.figure()
    plt.plot(np.arange(len(train_loss)) + 1, train_loss, label='train')
    plt.plot(np.arange(len(val_loss)) + 1, val_loss, label='val')
    plt.legend()
    plt.savefig('loss.png')
    plt.figure()
    plt.plot(np.arange(len(train_f1_mics)) + 1, train_f1_mics, label='train')
    plt.plot(np.arange(len(val_f1_mics)) + 1, val_f1_mics, label='val')
    plt.legend()
    plt.savefig('f1.png')

    # OTU f1
    plt.figure()
    plt.plot(otu_f1, label='otu_f1')
    plt.legend(loc=0)
    plt.xlabel('OTU')
    plt.ylabel('f1_score')
    plt.title('OTU f1 plot')
    plt.grid(True)
    plt.axis('tight')
    #plt.savefig("./graph/HOPE_same_otu_f1.png")
    plt.show()

    ko_none = f1_scores[2]
    # Ko f1 score
    plt.figure()
    plt.plot(ko_none, label='Ko f1 score')
    plt.legend(loc=0)
    plt.xlabel('Ko')
    plt.ylabel('f1_score')
    plt.grid(True)
    plt.axis('tight')
    #plt.savefig("./graph/HOPE_same_ko_f1.png")
    bad_ko = []
    b02 = 0
    b05 = 0
    b07 = 0
    for i in range(len(ko_none)):
        if ko_none[i] < 0.2:
            bad_ko.append(i)
            b02 += 1
        elif ko_none[i] < 0.5:
            b05 += 1
        elif ko_none[i] < 0.7:
            b07 += 1
    print("ko f1 below 0.2:", b02)
    print("ko f1 below 0.5:", b05)
    print("ko f1 below 0.7:", b07)
    print("ko f1 over 0.7:", len(ko_none) - b02 - b05 - b07)
    bad_ko = np.array(bad_ko)
    with open('./new_data_badko/graph10 ko below zero point two .txt',
              'w') as f:
        np.savetxt(f, bad_ko, fmt='%d', delimiter=",")
Example #5
def predict(train_data, test_data=None):
    num_k = FLAGS.num_k
    G = train_data[0]
    #features = train_data[1]
    #features_store = np.copy(features)
    id_map = train_data[1]
    class_map = train_data[3]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    # if not features is None:
    # 	# pad with dummy zero vector
    # 	features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(
        G,
        id_map,
        placeholders,
        class_map,
        num_classes,
        batch_size=FLAGS.batch_size,
        max_degree=FLAGS.max_degree,
        context_pairs=context_pairs,
        budget=num_k,
        bud_mul_fac=FLAGS.bud_mul_fac,
        mode="test",
        prefix=FLAGS.train_prefix,
        neighborhood_sampling=FLAGS.neighborhood_sampling)

    features = minibatch.features

    time_minibatch = minibatch.time_sampling_plus_norm  # time_minibatch_end - time_minibatch_start

    print("creatred features and sampling done")
    print("minibatch time", time_minibatch)

    adj_info_tf_time_beg = time.time()

    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    adj_info_tf_time_end = time.time()
    adj_info_tf_time = adj_info_tf_time_end - adj_info_tf_time_beg
    print("adj info time", adj_info_tf_time)

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True

    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    adj_info_tf_init_time_beg = time.time()

    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    adj_info_tf_init_time_end = time.time()
    adj_info_tf_init_time = adj_info_tf_init_time_end - adj_info_tf_init_time_beg
    print("adj info time init", adj_info_tf_init_time)

    #	train_adj_info = tf.assign(adj_info, minibatch.adj)
    #	val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    # Load trained model
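    # Only the trainable variables are restored; adj_info is non-trainable and was
    # already filled in from adj_info_ph by the initializer above.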
    var_to_save = []
    for var in tf.trainable_variables():
        var_to_save.append(var)
    saver = tf.train.Saver(var_to_save)
    print("Trained Model Loading!")
    saver.restore(sess, "KITE_supervisedTrainedModel_MC_marginal/model.ckpt")
    print("Trained Model Loaded!")

    sup_gs_pred_start_time = time.time()

    print("Predicting the classes of all the data set")
    predict_prob, pred_classes, embeddings = incremental_predict(
        sess, model, minibatch, FLAGS.batch_size)
    print("Predicted the classes of all the data set")
    sup_gs_pred_end_time = time.time()

    print("Saving the Predicted Output")
    sup_gs_prediction_time = sup_gs_pred_end_time - sup_gs_pred_start_time
    print("embed time ", sup_gs_prediction_time)

    # to output
    # print("predict_prob", predict_prob[143])
    # print("pred_classes", pred_classes[143])
    # print("embeddings", embeddings[143])
    print("Saved the Predicted Output")

    active_prob_beg_time = time.time()

    active_one_prob = {}
    active_one_prob_dict = {}

    for index in range(0, len(pred_classes)):
        active_one_prob[index] = predict_prob[index][0]
        # print(index,pred_classes[index],predict_prob[index])

    for index in range(0, len(pred_classes)):
        active_one_prob_dict[minibatch.dict_map_couter_to_real_node[
            minibatch.top_degree_nodes[index]]] = predict_prob[index][0]
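    # active_one_prob_dict: real node id -> predicted probability of the first class,
    # assuming prediction rows follow the order of minibatch.top_degree_nodes.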


    # bottom_nodes, top_nodes = bipartite.sets(G)
    # bottom_nodes = list(bottom_nodes)
    total_top_ten_percent = len(active_one_prob_dict)  # int(num_k*FLAGS.bud_mul_fac)  # int(0.25*len(bottom_nodes))

    #	if total_top_ten_percent>1000:
    #		total_top_ten_percent = 500

    sorted_dict = sorted(active_one_prob_dict.items(),
                         key=operator.itemgetter(1),
                         reverse=True)
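    # sorted_dict: (node id, probability) pairs in descending probability order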

    top_ten_percent = []

    count_solution = 0

    y = 0

    while count_solution < total_top_ten_percent:
        #if sorted_dict[y][0] in bottom_nodes:
        top_ten_percent.append(sorted_dict[y][0])
        count_solution = count_solution + 1
        y = y + 1

    active_prob_end_time = time.time()

    active_prob_time = active_prob_end_time - active_prob_beg_time

    print("actve prob time ", active_prob_time)

    result_top_percent = FLAGS.train_prefix + "_top_ten_percent{}_nbs{}.txt".format(
        num_k, FLAGS.neighborhood_sampling)

    file_handle2 = open(result_top_percent, "w")
    print('*******************', len(top_ten_percent))

    dict_node_scores = {}

    for ind in top_ten_percent:
        file_handle2.write(str(ind))
        file_handle2.write(" ")
        dict_node_scores[ind] = active_one_prob_dict[ind]

    file_handle2.close()

    dict_node_scores_file_name = FLAGS.train_prefix + "_node_scores_supgs{}_nbs{}".format(
        num_k, FLAGS.neighborhood_sampling)
    import pickle

    pickle_start_time = time.time()

    with open(dict_node_scores_file_name + '.pickle', 'wb') as handle:
        pickle.dump(dict_node_scores, handle, protocol=pickle.HIGHEST_PROTOCOL)
    pickle_end_time = time.time()
    pickle_time = pickle_end_time - pickle_start_time

    print("pickle scores time ", pickle_time)

    #print(dict_node_scores)

    result_top_percent = FLAGS.train_prefix + "_top_ten_percent_analyse{}_nbs{}.txt".format(
        num_k, FLAGS.neighborhood_sampling)

    file_handle2 = open(result_top_percent, "w")
    print('*******************', len(top_ten_percent))
    print('******************* Writing top to file')

    graph_degree = G.degree()

    for ind in top_ten_percent:
        #print(ind)
        file_handle2.write(
            str(ind) + "   " + str(graph_degree[ind]) + " " +
            str(active_one_prob_dict[ind]))
        file_handle2.write(" \n")

    file_handle2.close()

    print('******************* Written top to file')

    top_30 = []

    count_solution = 0
    y = 0

    while count_solution < num_k:
        #		if sorted_dict[y][0] in bottom_nodes:
        top_30.append(sorted_dict[y][0])
        count_solution = count_solution + 1
        y = y + 1

    result_file_name = FLAGS.train_prefix + "_sup_GS_sol{}.txt".format(num_k)

    file_handle = open(result_file_name, "w")
    file_handle.write(str(num_k))
    file_handle.write("\n")
    for ind in top_30:
        file_handle.write(str(ind))
        file_handle.write(" ")

    file_handle.close()

    from sklearn.preprocessing import StandardScaler

    # scaler = StandardScaler()
    # scaler.fit(embeddings)
    # embeddings = scaler.transform(embeddings)
    #

    embeddings = np.array(embeddings)
    # embeddings = np.hstack([embeddings, features_store])
    print('Final Embeddings shape = ', embeddings.shape)
    embedding_file_name = FLAGS.train_prefix + "_embeddings{}_nbs{}.npy".format(
        num_k, FLAGS.neighborhood_sampling)
    # np.save(embedding_file_name,embeddings)

    dict_embeddings_top_for_rl_without_rw = {}

    for index, node_id in enumerate(top_ten_percent):
        #print("map", node_id, minibatch.top_degree_nodes[index], index)
        embed_sup_gs = embeddings[index]

        dict_embeddings_top_for_rl_without_rw[node_id] = embed_sup_gs
    #	print("index, nodeid ", index, node_id)
    import pickle
    with open(embedding_file_name + '.pickle', 'wb') as handle:
        pickle.dump(dict_embeddings_top_for_rl_without_rw,
                    handle,
                    protocol=pickle.HIGHEST_PROTOCOL)

    total_time_for_rl_prep = adj_info_tf_time + time_minibatch + sup_gs_prediction_time + active_prob_time + adj_info_tf_init_time

    time_rl_prep_file_name = FLAGS.train_prefix + "_num_k_" + str(
        FLAGS.num_k) + "_time_nbs{}.txt".format(FLAGS.neighborhood_sampling)
    print(time_rl_prep_file_name)
    time_file = open(time_rl_prep_file_name, 'w')
    time_file.write("RL_PREP_TIME_" + str(total_time_for_rl_prep) + '\n')

    reward_file_name = FLAGS.train_prefix + ".sup_GS_reward{}".format(num_k)
    reward = evaluaterew.evaluate(G, top_30)
    file_handle3 = open(reward_file_name, "w")
    file_handle3.write(str(reward))
    file_handle3.close()

    print(" time rl prepare", total_time_for_rl_prep)
Пример #6
0
def train(train_data, test_data=None):

    G = train_data[0]
    # features = train_data[1]
    id_map = train_data[1]
    class_map = train_data[3]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))
    #
    # if not features is None:
    #     # pad with dummy zero vector
    #     features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G, 
            id_map,
            placeholders, 
            class_map,
            num_classes,
            batch_size=FLAGS.batch_size,
            max_degree=FLAGS.max_degree, 
            context_pairs=context_pairs, mode="train", prefix=FLAGS.train_prefix)

    features = minibatch.features


    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)]
        elif FLAGS.samples_2 != 0:
            layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
        else:
            layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)]

        model = SupervisedGraphsage(num_classes, placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos, 
                                     model_size=FLAGS.model_size,
                                     sigmoid_loss = FLAGS.sigmoid,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="gcn",
                                     model_size=FLAGS.model_size,
                                     concat=False,
                                     sigmoid_loss = FLAGS.sigmoid,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="seq",
                                     model_size=FLAGS.model_size,
                                     sigmoid_loss = FLAGS.sigmoid,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    
    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
     
    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
    
    # Train model
    print(layer_infos)
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)




    # Load trained model
    if os.path.exists("TVYoutubesupervisedTrainedModel_MC_marginal/"):
        print("Entered in the Loop  =           == = == = == ")
        var_to_save = []
        for var in tf.trainable_variables():
            var_to_save.append(var)
        saver = tf.train.Saver(var_to_save)
        print("Trained Model Loading!")
        saver.restore(sess,"TVYoutubesupervisedTrainedModel_MC_marginal/model.ckpt")
        print("Trained Model Loaded!")

    # ------------------------------------------------------------------
    for epoch in range(FLAGS.epochs): 
        minibatch.shuffle() 

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds], feed_dict=feed_dict)
            train_cost = outs[2]

            #
            # if iter % FLAGS.validate_iter == 0:
            #     # Validation
            #     sess.run(val_adj_info.op)
            #     if FLAGS.validate_batch_size == -1:
            #         val_cost, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size)
            #     else:
            #         val_cost, duration = evaluate(sess, model, minibatch, FLAGS.validate_batch_size)
            #     sess.run(train_adj_info.op)
            #     epoch_val_costs[-1] += val_cost
            #
            # if total_steps % FLAGS.print_every == 0:
            #     summary_writer.add_summary(outs[0], total_steps)
    
            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)


            if total_steps % FLAGS.print_every == 0:
                # train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter, 
                      "train_loss=", "{:.8f}".format(train_cost),
                   #   "val_loss=", "{:.8f}".format(val_cost),
                      "time=", "{:.5f}".format(avg_time))
 
            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break
    
    print("Optimization Finished!")

    var_to_save = []
    for var in tf.trainable_variables():
        var_to_save.append(var)
    saver = tf.train.Saver(var_to_save)
    save_path = saver.save(sess, "TVYoutubesupervisedTrainedModel_MC_marginal/model.ckpt")
    print("*** Saved: Model", save_path)
    
    # sess.run(val_adj_info.op)
    # val_cost, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size)
    # print("Full validation stats:",
    #               "loss=", "{:.12f}".format(val_cost),
    #               "time=", "{:.12f}".format(duration))
    # with open(log_dir() + "val_stats_train.txt", "w") as fp:
    #     fp.write("loss={:.5f} time={:.5f}".
    #             format(val_cost, duration))

    print("Writing test set stats to file (don't peak!)")
Пример #7
0
                                  id_map,
                                  placeholders,
                                  class_map,
                                  num_classes,
                                  batch_size=FLAGS.batch_size,
                                  max_degree=FLAGS.max_degree,
                                  context_pairs=context_pairs)
adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

if FLAGS.model == 'graphsage_mean':
    # Create model
    sampler = UniformNeighborSampler(adj_info)
    if FLAGS.samples_3 != 0:
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
            SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
        ]
    elif FLAGS.samples_2 != 0:
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]
    else:
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)]

    model = SupervisedGraphsage(num_classes,
                                placeholders,
                                features,
                                adj_info,
Пример #8
0
def train(train_data, test_data=None):

    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map  = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G, 
            id_map,
            placeholders, 
            class_map,
            num_classes,
            batch_size=FLAGS.batch_size,
            max_degree=FLAGS.max_degree, 
            context_pairs = context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
    
    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)]
        elif FLAGS.samples_2 != 0:
            layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
        else:
            layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)]
        '''        
        ### 3 layer test
        layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 25, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 10, FLAGS.dim_2)]
 
        '''
        model = SupervisedGraphsage(num_classes, placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos, 
                                     model_size=FLAGS.model_size,
                                     sigmoid_loss = FLAGS.sigmoid,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="gcn",
                                     model_size=FLAGS.model_size,
                                     concat=False,
                                     sigmoid_loss = FLAGS.sigmoid,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="seq",
                                     model_size=FLAGS.model_size,
                                     sigmoid_loss = FLAGS.sigmoid,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    
    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
    
    # Save model
    saver = tf.train.Saver()
    model_path = './model/' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix + '/'
    if not os.path.exists(model_path):
        os.makedirs(model_path)


    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
    
    # Train model
    
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
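    # These assign ops swap adj_info between the training adjacency and the full test
    # adjacency; they only take effect when executed below via sess.run(...).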
    
    
    val_cost_ = []
    val_f1_mic_ = []
    val_f1_mac_ = []
    duration_ = []
    
    for epoch in range(FLAGS.epochs): 
        minibatch.shuffle() 

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds], feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(sess, model, minibatch, FLAGS.validate_batch_size)
                
                # accumulate val results
                val_cost_.append(val_cost)
                val_f1_mic_.append(val_f1_mic)
                val_f1_mac_.append(val_f1_mac)
                duration_.append(duration)

                #
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost


            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)
    
            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter, 
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_f1_mic=", "{:.5f}".format(train_f1_mic), 
                      "train_f1_mac=", "{:.5f}".format(train_f1_mac), 
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_f1_mic=", "{:.5f}".format(val_f1_mic), 
                      "val_f1_mac=", "{:.5f}".format(val_f1_mac), 
                      "time=", "{:.5f}".format(avg_time))
 
            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break
        
    
        # Save model
        save_path = saver.save(sess, model_path+'model.ckpt')
        print('model is saved at %s' % save_path)
    

    print("Validation per epoch in training")
    for ep in range(FLAGS.epochs):
        print("Epoch: %04d"%ep, " val_cost={:.5f}".format(val_cost_[ep]), " val_f1_mic={:.5f}".format(val_f1_mic_[ep]), " val_f1_mac={:.5f}".format(val_f1_mac_[ep]), " duration={:.5f}".format(duration_[ep]))
    
    print("Optimization Finished!")
    sess.run(val_adj_info.op)

    # full validation 
    val_cost_ = []
    val_f1_mic_ = []
    val_f1_mac_ = []
    duration_ = []
    for iter in range(10):
        val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size)
        print("Full validation stats:",
                          "loss=", "{:.5f}".format(val_cost),
                          "f1_micro=", "{:.5f}".format(val_f1_mic),
                          "f1_macro=", "{:.5f}".format(val_f1_mac),
                          "time=", "{:.5f}".format(duration))

        val_cost_.append(val_cost)
        val_f1_mic_.append(val_f1_mic)
        val_f1_mac_.append(val_f1_mac)
        duration_.append(duration)
  
    # write validation results
    with open(log_dir() + "val_stats.txt", "w") as fp:
        for iter in range(10):
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".format(val_cost_[iter], val_f1_mic_[iter], val_f1_mac_[iter], duration_[iter]))
        
        fp.write("mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".format(np.mean(val_cost_), np.mean(val_f1_mic_), np.mean(val_f1_mac_), np.mean(duration_)))
        fp.write("variance: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".format(np.var(val_cost_), np.var(val_f1_mic_), np.var(val_f1_mac_), np.var(duration_)))
        

    # test 
    val_cost_ = []
    val_f1_mic_ = []
    val_f1_mac_ = []
    duration_ = []

    print("Writing test set stats to file (don't peak!)")
    for iter in range(10):
        val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size, test=True)
    
        val_cost_.append(val_cost)
        val_f1_mic_.append(val_f1_mic)
        val_f1_mac_.append(val_f1_mac)
        duration_.append(duration)
   
    # write test results
    with open(log_dir() + "test_stats.txt", "w") as fp:
        for iter in range(10):
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".
                        format(val_cost_[iter], val_f1_mic_[iter], val_f1_mac_[iter], duration_[iter]))
        
        fp.write("mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".
                        format(np.mean(val_cost_), np.mean(val_f1_mic_), np.mean(val_f1_mac_), np.mean(duration_)))
        fp.write("variance: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".
                        format(np.var(val_cost_), np.var(val_f1_mic_), np.var(val_f1_mac_), np.var(duration_)))
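
# --- Illustrative sketch (not part of the example above) ---
# What the calc_f1 helper used above most likely computes, mirroring the reference
# GraphSAGE utility; the 0.5 threshold for the sigmoid (multi-label) case is an
# assumption here.
import numpy as np
from sklearn import metrics


def calc_f1_sketch(y_true, y_pred, sigmoid_loss):
    if not sigmoid_loss:
        # single-label case: pick the argmax class
        y_true = np.argmax(y_true, axis=1)
        y_pred = np.argmax(y_pred, axis=1)
    else:
        # multi-label case: threshold the per-class probabilities
        y_pred = (np.asarray(y_pred) > 0.5).astype(int)
    return (metrics.f1_score(y_true, y_pred, average="micro"),
            metrics.f1_score(y_true, y_pred, average="macro"))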
Пример #9
0
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    # sess.run(feed_dict={adj_info_ph: minibatch.adj})

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    saver = tf.train.Saver()
    saver.restore(sess, "supervisedTrainedModel_MC/model.ckpt")

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)
        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())
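
# --- Illustrative sketch (not part of the example above) ---
# Reading back what save_val_embeddings typically writes in the reference GraphSAGE
# code: an embedding matrix ("val.npy") plus the matching node ids ("val.txt", one id
# per line). This file layout is an assumption; adjust the names if yours differ.
import os
import numpy as np


def load_val_embeddings(out_dir):
    embeds = np.load(os.path.join(out_dir, "val.npy"))
    with open(os.path.join(out_dir, "val.txt")) as f:
        node_ids = [line.strip() for line in f]
    return dict(zip(node_ids, embeds))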
Пример #10
0
def train(train_data, test_data=None):
    G = train_data[0]  # G is a networkx object; all of these come from load_data()
    features = train_data[1]
    id_map = train_data[2]
    class_map1 = train_data[4]
    class_map2 = train_data[5]
    class_map3 = train_data[6]
    dict_classmap = {
        0: class_map1,
        1: class_map2,
        2: class_map3,
        3: class_map3
    }
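    # one class map per hierarchy level; the last level reuses class_map3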
    hierarchy = FLAGS.hierarchy
    features_shape1 = None
    a_class = construct_class_numpy(class_map1)
    b_class = construct_class_numpy(class_map2)
    c_class = construct_class_numpy(class_map3)
    a_class = tf.cast(a_class, tf.float32)
    b_class = tf.cast(b_class, tf.float32)
    c_class = tf.cast(c_class, tf.float32)

    num_class = []
    #    for key in class_map.keys():
    #        num_class = num_class.append(sum(class_map[key]))

    for hi_num in range(hierarchy):
        #tf.reset_default_graph()
        if hi_num == 0:
            class_map = class_map1
            features = features
            features_shape1 = features.shape[1]
            if features is not None:
                # pad with dummy zero vector
                features = np.vstack(
                    [features, np.zeros((features.shape[1], ))])
            features = tf.cast(features, tf.float32)

        else:
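            # Later hierarchy levels reuse features2 (node features concatenated with
            # the previous level's predictions) and append a zero row for the dummy
            # padding node.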
            print("hierarchy %d finished" % (hi_num), end='\n\n')
            class_map = dict_classmap[hi_num]
            features = features2
            features = tf.cast(features, tf.float32)
            features = tf.concat(
                [features,
                 tf.zeros([1, features_shape1 + num_classes])],
                axis=0)
            features_shape1 = features.shape[1]

        if hi_num == 0:
            if isinstance(list(class_map.values())[0], list):
                num_classes = len(list(class_map.values())[0])
            else:
                num_classes = len(set(class_map.values()))
        else:
            if isinstance(list(dict_classmap[hi_num].values())[0], list):
                num_classes = len(list(dict_classmap[hi_num].values())[0])
            else:
                num_classes = len(set(dict_classmap[hi_num].values()))
        """"" 
        if features is not None:
            # pad with dummy zero vector
            features = np.vstack([features, np.zeros((features.shape[1],))])
        """ ""

        # features = tf.cast(features, tf.float32)
        # embeding_weight=tf.get_variable('emb_weights', [50, 128], initializer=tf.random_normal_initializer(),dtype=tf.float32)
        # features=tf.matmul(features,embeding_weight)
        context_pairs = train_data[3] if FLAGS.random_context else None
        placeholders = construct_placeholders(num_classes)
        minibatch = NodeMinibatchIterator(G,
                                          id_map,
                                          placeholders,
                                          class_map,
                                          num_classes,
                                          batch_size=FLAGS.batch_size,
                                          max_degree=FLAGS.max_degree,
                                          context_pairs=context_pairs)
        ##########
        with open('test_nodes.txt', 'w') as f:
            json.dump(minibatch.test_nodes, f)
        ###########
        if hi_num == 0:
            adj_info_ph = tf.placeholder(tf.int32,
                                         shape=minibatch.adj.shape,
                                         name='adj_info_ph')

        # adj_info is made a tf.Variable because its value is swapped between training
        # and testing, so it is declared as a Variable and updated via tf.assign().
        adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

        shap.initjs()
        if FLAGS.model == 'graphsage_mean':
            # Create model
            sampler = UniformNeighborSampler(adj_info)

            if FLAGS.samples_3 != 0:
                layer_infos = [
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                    SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
                ]

            elif FLAGS.samples_2 != 0:
                layer_infos = [
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
                ]

            else:
                layer_infos = [
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
                ]

            model = SupervisedGraphsage(
                num_classes,
                placeholders,
                features,
                adj_info,
                minibatch.deg,  # degree of each node
                layer_infos,
                model_size=FLAGS.model_size,
                sigmoid_loss=FLAGS.sigmoid,
                identity_dim=FLAGS.identity_dim,
                logging=True,
                concat=True,
            )

        elif FLAGS.model == 'gcn':
            # Create model
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        aggregator_type="gcn",
                                        model_size=FLAGS.model_size,
                                        concat=False,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True)

        elif FLAGS.model == 'graphsage_seq':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        aggregator_type="seq",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True,
                                        concat=True)

        elif FLAGS.model == 'graphsage_maxpool':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        aggregator_type="maxpool",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True,
                                        concat=True)

        elif FLAGS.model == 'graphsage_meanpool':
            sampler = UniformNeighborSampler(adj_info)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos=layer_infos,
                                        aggregator_type="meanpool",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True,
                                        concat=True)
        elif FLAGS.model == 'gat':
            sampler = UniformNeighborSampler(adj_info)
            # build a two-layer network: (neighbor sampler, number of sampled neighbors, output dimension)
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]

            model = SupervisedGraphsage(
                num_classes,
                placeholders,
                features,
                adj_info,
                minibatch.deg,
                concat=True,
                layer_infos=layer_infos,
                aggregator_type="gat",
                model_size=FLAGS.model_size,
                sigmoid_loss=FLAGS.sigmoid,
                identity_dim=FLAGS.identity_dim,
                logging=True,
            )
        else:
            raise Exception('Error: model name unrecognized.')

        config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
        config.allow_soft_placement = True

        # Initialize session

        sess = tf.Session(config=config)
        # sess = tf_dbg.LocalCLIDebugWrapperSession(sess)
        #merged = tf.summary.merge_all()  # write all summaries to disk for visualization
        #summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)  # log info for viewing with TensorBoard

        # Init variables

        sess.run(tf.global_variables_initializer(),
                 feed_dict={adj_info_ph: minibatch.adj})
        #sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph2: minibatch2.adj})

        # Train model
        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []
        epoch_val_costs2 = []
        # minibatch.adj and minibatch.test_adj have the same shape; in adj the entries
        # that do not belong to the training set are masked out
        # "val" here means validation
        # tf.assign() assigns a value to a tf.Variable and returns the assigned value;
        # assign() is an Operation, so it only takes effect when run via sess.run()
        train_adj_info = tf.assign(adj_info, minibatch.adj)
        val_adj_info = tf.assign(adj_info, minibatch.test_adj)

        it = 0
        train_loss = []
        val_loss = []
        train_f1_mics = []
        val_f1_mics = []
        loss_plt = []
        loss_plt2 = []
        trainf1mi = []
        trainf1ma = []
        valf1mi = []
        valf1ma = []
        iter_num = 0

        if hi_num == 0:
            epochs = FLAGS.epochs
        elif hi_num == 1:
            epochs = FLAGS.epochs2
        elif hi_num == 2:
            epochs = FLAGS.epochs3
        else:
            epochs = FLAGS.epochs4

        for epoch in range(epochs + 1):
            if epoch < epochs:
                minibatch.shuffle()
                iter = 0
                print('Epoch: %04d' % (epoch + 1))
                epoch_val_costs.append(0)
                while not minibatch.end():
                    # Construct feed dictionary
                    # the nodes in each minibatch are changed by updating feed_dict,
                    # which holds the placeholders filled in by the minibatch iterator
                    feed_dict, labels = minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    t = time.time()
                    # Training step
                    outs = sess.run([model.opt_op, model.loss, model.preds],
                                    feed_dict=feed_dict)
                    train_cost = outs[1]
                    iter_num = iter_num + 1
                    loss_plt.append(float(train_cost))
                    if iter % FLAGS.print_every == 0:
                        # Validation
                        # sess.run() with an Operation as the fetch executes that operation
                        sess.run(val_adj_info.op)
                        if FLAGS.validate_batch_size == -1:
                            val_cost, val_f1_mic, val_f1_mac, duration, otu_lazy, _ = incremental_evaluate(
                                sess, model, minibatch, FLAGS.batch_size)
                        else:
                            val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                                sess, model, minibatch,
                                FLAGS.validate_batch_size)
                        # every tensor has an .op attribute: the operation that produced it
                        sess.run(train_adj_info.op)
                        epoch_val_costs[-1] += val_cost

                    #if iter % FLAGS.print_every == 0:
                    #summary_writer.add_summary(outs[0], total_steps)

                    # Print results
                    avg_time = (avg_time * total_steps + time.time() -
                                t) / (total_steps + 1)
                    loss_plt2.append(float(val_cost))
                    valf1mi.append(float(val_f1_mic))
                    valf1ma.append(float(val_f1_mac))

                    if iter % FLAGS.print_every == 0:
                        train_f1_mic, train_f1_mac, train_f1_none = calc_f1(
                            labels, outs[-1])
                        trainf1mi.append(float(train_f1_mic))
                        trainf1ma.append(float(train_f1_mac))

                        print(
                            "Iter:",
                            '%04d' % iter,
                            # loss and other metrics on the training set
                            "train_loss=",
                            "{:.5f}".format(train_cost),
                            "train_f1_mic=",
                            "{:.5f}".format(train_f1_mic),
                            "train_f1_mac=",
                            "{:.5f}".format(train_f1_mac),
                            # loss and other metrics on the validation set
                            "val_loss=",
                            "{:.5f}".format(val_cost),
                            "val_f1_mic=",
                            "{:.5f}".format(val_f1_mic),
                            "val_f1_mac=",
                            "{:.5f}".format(val_f1_mac),
                            "time=",
                            "{:.5f}".format(avg_time))
                        train_loss.append(train_cost)
                        val_loss.append(val_cost)
                        train_f1_mics.append(train_f1_mic)
                        val_f1_mics.append(val_f1_mic)
                    iter += 1
                    total_steps += 1
                    if total_steps > FLAGS.max_total_steps:
                        break
                if total_steps > FLAGS.max_total_steps:
                    break

            # concat features
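            # After the last training epoch (epoch == epochs): for intermediate levels,
            # the branch below runs the trained model over all train/val/test nodes and
            # concatenates its logits (model.node_preds) onto the node features to build
            # features2 for the next hierarchy level.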
            elif hi_num == FLAGS.hierarchy - 1:
                print("the last outputs")
            else:
                iter = 0
                minibatch.shuffle()
                while not minibatch.end():
                    print("Iter:", '%04d' % iter, "concat")
                    # feed_dict holds the placeholders filled in by the minibatch iterator
                    feed_dict, labels = minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    x = feed_dict[placeholders['batch']]
                    outs = sess.run([
                        model.opt_op, model.loss, model.preds, model.node_preds
                    ],
                                    feed_dict=feed_dict)
                    features_tail = outs[3]
                    features_tail = tf.cast(features_tail, tf.float32)
                    """""
                    if hi_num == 0:
                        features_tail = tf.nn.embedding_lookup(a_class, feed_dict[placeholders["batch"]])
                    elif hi_num == 1:
                        features_tail = tf.nn.embedding_lookup(b_class, feed_dict[placeholders["batch"]])
                    else:
                        features_tail = tf.nn.embedding_lookup(c_class, feed_dict[placeholders["batch"]])
                    """ ""
                    hidden = tf.nn.embedding_lookup(
                        features, feed_dict[placeholders["batch"]])
                    features_inter = tf.concat([hidden, features_tail], axis=1)

                    if iter == 0:
                        features2 = features_inter
                    else:
                        features2 = tf.concat([features2, features_inter],
                                              axis=0)
                    iter += 1

                # val features & test features
                iter_num = 0
                finished = False
                while not finished:
                    feed_dict_val, batch_labels, finished, _ = minibatch.incremental_node_val_feed_dict(
                        FLAGS.batch_size, iter_num, test=False)
                    node_outs_val = sess.run(
                        [model.preds, model.loss, model.node_preds],
                        feed_dict=feed_dict_val)
                    tail_val = tf.cast(node_outs_val[2], tf.float32)
                    hidden_val = tf.nn.embedding_lookup(
                        features, feed_dict_val[placeholders["batch"]])
                    features_inter_val = tf.concat([hidden_val, tail_val],
                                                   axis=1)
                    iter_num += 1
                    features2 = tf.concat([features2, features_inter_val],
                                          axis=0)
                print("val features finished")
                iter_num = 0
                finished = False
                while not finished:
                    feed_dict_test, batch_labels, finished, _ = minibatch.incremental_node_val_feed_dict(
                        FLAGS.batch_size, iter_num, test=True)
                    node_outs_test = sess.run(
                        [model.preds, model.loss, model.node_preds],
                        feed_dict=feed_dict_test)
                    tail_test = tf.cast(node_outs_test[2], tf.float32)
                    hidden_test = tf.nn.embedding_lookup(
                        features, feed_dict_test[placeholders["batch"]])
                    features_inter_test = tf.concat([hidden_test, tail_test],
                                                    axis=1)
                    iter_num += 1
                    features2 = tf.concat([features2, features_inter_test],
                                          axis=0)
                print("test features finished")

                print("finish features concat")
                #features2 = sess.run(features2)

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_f1_mic, val_f1_mac, duration, otu_f1, ko_none = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
          "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
    pred = y_ture_pre(sess, model, minibatch, FLAGS.batch_size)
    # normalize each row of predictions so the class scores sum to 1
    for i in range(pred.shape[0]):
        row_sum = 0
        for l in range(pred.shape[1]):
            row_sum = row_sum + pred[i, l]
        for m in range(pred.shape[1]):
            pred[i, m] = pred[i, m] / row_sum
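    # Sketch (not executed): the row normalization above can also be done in one
    # vectorized step with NumPy broadcasting, assuming `pred` is a dense 2-D array
    # of non-negative scores:
    #     pred = pred / pred.sum(axis=1, keepdims=True)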
    id = json.load(open(FLAGS.train_prefix + "-id_map.json"))
    num = 0
    session = tf.Session()
    array = session.run(features)
    x_test = np.empty([pred.shape[0], array.shape[1]])
    x_train = np.empty([len(minibatch.train_nodes), array.shape[1]])
    for node in minibatch.val_nodes:
        x_test[num] = array[id[node]]
        num = num + 1
    num1 = 0
    for node in minibatch.train_nodes:
        x_train[num1] = array[id[node]]
        num1 = num1 + 1

    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write(
            "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration, otu_lazy, ko_none = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            val_cost, val_f1_mic, val_f1_mac))

    incremental_evaluate_for_each(sess,
                                  model,
                                  minibatch,
                                  FLAGS.batch_size,
                                  test=True)

    ##################################################################################################################
    # plot loss
    plt.figure()
    plt.plot(loss_plt, label='train_loss')
    plt.plot(loss_plt2, label='val_loss')
    plt.legend(loc=0)
    plt.xlabel('Iteration')
    plt.ylabel('loss')
    plt.title('Loss plot')
    plt.grid(True)
    plt.axis('tight')
    #plt.savefig("./graph/HMC12_loss.png")
    # plt.show()

    # plot f1 score
    plt.figure()
    plt.subplot(211)
    plt.plot(trainf1mi, label='train_f1_micro')
    plt.plot(valf1mi, label='val_f1_micro')
    plt.legend(loc=0)
    plt.xlabel('Iterations')
    plt.ylabel('f1_micro')
    plt.title('train_val_f1_score')
    plt.grid(True)
    plt.axis('tight')

    plt.subplot(212)
    plt.plot(trainf1ma, label='train_f1_macro')
    plt.plot(valf1ma, label='val_f1_macro')
    plt.legend(loc=0)
    plt.xlabel('Iteration')
    plt.ylabel('f1_macro')
    plt.grid(True)
    plt.axis('tight')
    #plt.savefig("./graph/HMC123_f1.png")
    # plt.show()

    plt.figure()
    plt.plot(np.arange(len(train_loss)) + 1, train_loss, label='train')
    plt.plot(np.arange(len(val_loss)) + 1, val_loss, label='val')
    plt.legend()
    plt.savefig('loss.png')
    plt.figure()
    plt.plot(np.arange(len(train_f1_mics)) + 1, train_f1_mics, label='train')
    plt.plot(np.arange(len(val_f1_mics)) + 1, val_f1_mics, label='val')
    plt.legend()
    plt.savefig('f1.png')

    # OTU f1
    plt.figure()
    plt.plot(otu_f1, label='otu_f1')
    plt.legend(loc=0)
    plt.xlabel('OTU')
    plt.ylabel('f1_score')
    plt.title('OTU f1 plot')
    plt.grid(True)
    plt.axis('tight')
    #plt.savefig("./graph/HMC123_otu_f1.png")
    # plt.show()

    #Ko f1 score
    plt.figure()
    plt.plot(ko_none, label='Ko f1 score')
    plt.legend(loc=0)
    plt.xlabel('Ko')
    plt.ylabel('f1_score')
    plt.grid(True)
    plt.axis('tight')
    #plt.savefig("./graph/HMC123_ko_f1.png")

    bad_ko = []
    b02 = 0
    b05 = 0
    b07 = 0
    for i in range(len(ko_none)):
        if ko_none[i] < 0.2:
            bad_ko.append(i)
            b02 += 1
        elif ko_none[i] < 0.5:
            b05 += 1
        elif ko_none[i] < 0.7:
            b07 += 1
    # convert to an array only after the loop; converting inside it would break append()
    bad_ko = np.array(bad_ko)
    print("ko f1 below 0.2:", b02)
    print("ko f1 below 0.5:", b05)
    print("ko f1 below 0.7:", b07)
Example #11
0
def train(train_data,
          log_dir,
          Theta=None,
          fixed_neigh_weights=None,
          test_data=None,
          neg_sample_weights=None):
    G = train_data[0]
    Gneg = Graph_complement(G)
    features = train_data[1]
    id_map = train_data[2]
    Adj_mat = Edges_to_Adjacency_mat(G.edges(), len(G.nodes()))
    #     print('A in unsup: ', Adj_mat)
    #     negAdj_mat = Edges_to_Adjacency_mat(Gneg.edges(), len(Gneg.nodes()))
    FLAGS.batch_size = batch_size_def(len(G.edges()))
    FLAGS.negbatch_size = batch_size_def(len(Gneg.edges()))
    FLAGS.samples_1 = min(25, len(G.nodes()))
    FLAGS.samples_2 = min(10, len(G.nodes()))
    FLAGS.negsamples_1 = min(25, len(Gneg.nodes()))
    FLAGS.negsamples_2 = min(10, len(Gneg.nodes()))
    FLAGS.neg_sample_size = FLAGS.batch_size

    if (len(G.nodes()) <= small_big_threshold):
        FLAGS.model_size = 'small'
    else:
        FLAGS.model_size = 'big'

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])
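        # The appended all-zero row acts as the feature vector for the padding index
        # that the adjacency lists use for missing neighbors (presumably index |V|).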

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    #print('placeholders: ', placeholders)
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    aggbatch_size = len(minibatch.agg_batch_Z1)
    negaggbatch_size = len(minibatch.agg_batch_Z3)
    negminibatch = EdgeMinibatchIterator(Gneg,
                                         id_map,
                                         placeholders,
                                         batch_size=FLAGS.negbatch_size,
                                         max_degree=FLAGS.max_degree,
                                         num_neg_samples=FLAGS.neg_sample_size,
                                         context_pairs=context_pairs)

    #adj_info = tf.Variable(placeholders['adj_info_ph'], trainable=False, name="adj_info")
    adj_info_ph = tf.placeholder(tf.int32,
                                 shape=minibatch.adj.shape,
                                 name="adj_info_ph")
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
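    # Feeding the adjacency through a placeholder into a non-trainable Variable avoids
    # baking the (potentially large) adjacency array into the graph definition as a
    # constant; it is supplied once via feed_dict when variables are initialized.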

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1,
                     FLAGS.negsamples_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2,
                     FLAGS.negsamples_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   Adj_mat=Adj_mat,
                                   non_edges=Gneg.edges(),
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True,
                                   fixed_theta_1=Theta,
                                   fixed_neigh_weights=fixed_neigh_weights,
                                   neg_sample_weights=neg_sample_weights,
                                   aggbatch_size=aggbatch_size,
                                   negaggbatch_size=negaggbatch_size)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    #summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    #minibatch.adj = minibatch.adj.astype(np.int32)
    #print('minibatch.adj.shape: %s, dtype: %s' % (minibatch.adj.shape, np.ndarray.dtype(minibatch.adj)))
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
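    # train_adj_info / val_adj_info are assign ops that swap the adjacency stored in
    # adj_info between the training and validation/test neighbor lists; running one of
    # them (e.g. sess.run(val_adj_info.op)) switches which graph the samplers see.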

    print_flag = False
    if (Theta is None):
        print_flag = True

    for epoch in range(FLAGS.epochs):

        minibatch.shuffle()
        negminibatch.shuffle()
        iter = 0
        #         print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        if (FLAGS.model_size == 'big'):
            whichbatch = minibatch
        elif (len(G.edges()) > len(Gneg.edges())):
            whichbatch = minibatch
            opwhichbatch = negminibatch
        else:
            whichbatch = negminibatch
            opwhichbatch = minibatch
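        # The epoch loop below is driven by whichever iterator covers more edges
        # (whichbatch); the smaller one (opwhichbatch) is reshuffled whenever it runs
        # out, so positive and negative minibatches stay paired for the whole epoch.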

        while (not whichbatch.end()):
            if (FLAGS.model_size == 'small' and opwhichbatch.end()):
                opwhichbatch.shuffle()
            # Construct feed dictionary

            feed_dict = minibatch.next_minibatch_feed_dict()
            negfeed_dict = negminibatch.next_minibatch_feed_dict()

            if (True):
                feed_dict.update({
                    placeholders['negbatch1']:
                    negfeed_dict[placeholders['batch1']]
                })
                feed_dict.update({
                    placeholders['negbatch2']:
                    negfeed_dict[placeholders['batch2']]
                })
                feed_dict.update({
                    placeholders['negbatch_size']:
                    negfeed_dict[placeholders['batch_size']]
                })
            else:
                batch1 = feed_dict[placeholders['batch1']]
                feed_dict.update({placeholders['negbatch1']: batch1})
                feed_dict.update({
                    placeholders['negbatch2']:
                    negminibatch.feed_dict_negbatch(batch1)
                })
                feed_dict.update({
                    placeholders['negbatch_size']:
                    negfeed_dict[placeholders['batch_size']]
                })

            if (Theta is not None):
                break
            if (total_steps == 0):
                print_summary(model, feed_dict, sess, Adj_mat, 'first',
                              print_flag)

            t = time.time()
            # Training step
            outs = sess.run(
                [
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1, model.outputs2,
                    model.negoutputs1, model.negoutputs2
                ],
                feed_dict=feed_dict
            )  #, model.current_similarity   , model.learned_vars['theta_1']  # , model.aggregation
            train_cost = outs[2]
            train_mrr = outs[5]

            #             Z = outs[8]
            #             outs = np.concatenate((outs[6],outs[7]), axis=0)
            #             indices = np.concatenate((feed_dict[placeholders['batch1']],feed_dict[placeholders['batch2']]))
            #             Z = np.zeros((len(G.nodes()),FLAGS.dim_2*2))
            #             for node in G.nodes():
            #                 Z[node,:] = outs[int(np.argwhere(indices==node)[0]),:]#[0:len(G.nodes())] # 8
            #             print('Z shape: ', Z.shape)

            #             if train_shadow_mrr is None:
            #                 train_shadow_mrr = train_mrr #
            #             else:
            #                 train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)
            #
            #             if iter % FLAGS.validate_iter == 0:
            #                 # Validation
            #                 sess.run(val_adj_info.op)
            #                 val_cost, ranks, val_mrr, duration = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size, negminibatch_iter=negminibatch, placeholders=placeholders)
            #                 sess.run(train_adj_info.op)
            #                 epoch_val_costs[-1] += val_cost
            #
            #             if shadow_mrr is None:
            #                 shadow_mrr = val_mrr
            #             else:
            #                 shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)

            #             if total_steps % FLAGS.print_every == 0:
            #                 summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print('loss: ', outs[2])
#                 print("Iter:", '%04d' % iter,
#                       "train_loss=", "{:.5f}".format(train_cost),
#                       "train_mrr=", "{:.5f}".format(train_mrr),
#                       "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
#                       "val_loss=", "{:.5f}".format(val_cost),
#                       "val_mrr=", "{:.5f}".format(val_mrr),
#                       "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
#                       "time=", "{:.5f}".format(avg_time))

#             similarity_weights = outs[7]
#             [new_adj_info, new_batch_edges] = sess.run([model.adj_info, \
#                                                                 model.new_batch_edges], \
#                                                                   feed_dict=feed_dict)
#             minibatch.graph_update(new_adj_info, new_batch_edges)

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if (Theta is not None):
            break
        if total_steps > FLAGS.max_total_steps:
            break


#     print("SGD Optimization Finished!")

#     feed_dict = dict()
#     feed_dict.update({placeholders['batch_size'] : len(G.nodes())})

#     minibatch = EdgeMinibatchIterator(G,
#             id_map,
#             placeholders, batch_size = len(G.edges()),
#             max_degree = FLAGS.max_degree,
#             num_neg_samples = FLAGS.neg_sample_size,
#             context_pairs = context_pairs)
#     feed_dict = minibatch.next_minibatch_feed_dict()
#     _, Z = sess.run([merged, model.aggregation], feed_dict=feed_dict) #, model.concat   ,     aggregator_cls.vars

#     Z_tilde = np.repeat(Z, [len(G.nodes())], axis=0)
#     Z_tilde_tilde = np.tile(Z, (len(G.nodes()),1))
#     final_theta_1 = Quadratic_SDP_solver (Z_tilde, Z_tilde_tilde, FLAGS.dim_1*2, FLAGS.dim_2*2)
# #     Z_centralized = Z - np.mean(Z, axis=0)
# #     final_adj_matrix = np.abs(np.matmul(Z_centralized, np.matmul(final_theta_1, np.transpose(Z_centralized))))
#     final_adj_matrix = np.matmul(Z, np.matmul(final_theta_1, np.transpose(Z)))
#     feed_dict = minibatch.next_minibatch_feed_dict()
#     feed_dict.update({placeholders['dropout']: FLAGS.dropout})
#     feed_dict.update({placeholders['all_nodes']: all_nodes})

#     FLAGS.batch_size = len(G.edges())
#     FLAGS.negbatch_size = len(Gneg.edges())
#     FLAGS.samples_1 = len(G.nodes())
#     FLAGS.samples_2 = len(G.nodes())
#     FLAGS.negsamples_1 = len(Gneg.nodes())
#     FLAGS.negsamples_2 = len(Gneg.nodes())
#     feed_dict = minibatch.batch_feed_dict(G.edges())
#     negfeed_dict = negminibatch.batch_feed_dict(Gneg.edges())
#     feed_dict.update({placeholders['negbatch1']: negfeed_dict[placeholders['batch1']]})
#     feed_dict.update({placeholders['negbatch2']: negfeed_dict[placeholders['batch2']]})
#     feed_dict.update({placeholders['negbatch_size']: negfeed_dict[placeholders['batch_size']]})
#     if(outs is not None):
#         print('loss: ', outs[2])

    final_adj_matrix, final_theta_1, Z, loss, U = print_summary(
        model, feed_dict, sess, Adj_mat, 'last', print_flag)

    #print('Z shape: ', Z.shape)

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
    #del adj_info_ph, adj_info,placeholders
    return final_adj_matrix, G, final_theta_1, Z, loss, U  #, learned_vars
Example #12
0
def train(train_data, NextDoorKHopSampler, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)

    layer_infos_top_down = [SAGEInfo("node", None, FLAGS.samples_1, FLAGS.dim_1),
                   SAGEInfo("node", None, FLAGS.samples_2, FLAGS.dim_2)]

    minibatch = NodeMinibatchIteratorWithKHop(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      layer_infos_top_down,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)

    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
    sampler = UniformNeighborSampler(adj_info)
    layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
    model = SampledSupervisedGraphsage(num_classes, placeholders,
                                           features,
                                           adj_info,
                                           minibatch.deg,
                                           layer_infos=layer_infos,
                                           model_size=FLAGS.model_size,
                                           sigmoid_loss=FLAGS.sigmoid,
                                           identity_dim=FLAGS.identity_dim,
                                           logging=True)
    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
    lib = ctypes.CDLL("./KHopSamplingPy3.so")
    print("NextDoorKHopSampler.finalSampleLength() ", NextDoorKHopSampler.finalSampleLength())
    lib.finalSamplesArray.restype = ndpointer(dtype=ctypes.c_int, shape=(min(NextDoorKHopSampler.finalSampleLength(), 2**28)))
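    # Setting restype to numpy.ctypeslib.ndpointer lets ctypes expose the int array
    # returned by finalSamplesArray() as a NumPy array of the given (capped) shape.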

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    no_epochs = FLAGS.epochs
    total_epoch_time = 0
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()
        s_time = time.time()
        NextDoorKHopSampler.sample()
        finalSamples = lib.finalSamplesArray()
        minibatch.nextdoorFinalSamples = finalSamples
        s_time = time.time()
        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            # sess.run([model.preds],feed_dict=feed_dict)
            # assert (False)
            outs = sess.run([merged, model.opt_op, model.loss, model.preds], feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch,
                                                                                      FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(sess, model, minibatch,
                                                                          FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter,
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_f1_mic=", "{:.5f}".format(train_f1_mic),
                      "train_f1_mac=", "{:.5f}".format(train_f1_mac),
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_f1_mic=", "{:.5f}".format(val_f1_mic),
                      "val_f1_mac=", "{:.5f}".format(val_f1_mac),
                      "time=", "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        total_epoch_time = total_epoch_time + time.time() - s_time
        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:",
          "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic),
          "f1_macro=", "{:.5f}".format(val_f1_mac),
          "time=", "{:.5f}".format(duration))
    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".
                 format(val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size,
                                                                      test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".
                 format(val_cost, val_f1_mic, val_f1_mac))
    sess.close()
    return total_epoch_time / no_epochs
Example #13
0
def train(train_data, test_data=None, sampler_name='Uniform'):

    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]

    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    adj_shape = adj_info.get_shape().as_list()

    #    loss_node = tf.SparseTensor(indices=np.empty((0,2), dtype=np.int64), values=[], dense_shape=[adj_shape[0], adj_shape[0]])
    #    loss_node_count = tf.SparseTensor(indices=np.empty((0,2), dtype=np.int64), values=[], dense_shape=[adj_shape[0], adj_shape[0]])
    #
    # newly added for storing cost in each adj cell
    #    loss_node = tf.Variable(tf.zeros([minibatch.adj.shape[0], minibatch.adj.shape[0]]), trainable=False, name="loss_node", dtype=tf.float32)
    #    loss_node_count = tf.Variable(tf.zeros([minibatch.adj.shape[0], minibatch.adj.shape[0]]), trainable=False, name="loss_node_count", dtype=tf.float32)

    if FLAGS.model == 'mean_concat':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_3)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]
        '''        
        ### 3 layer test
        layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 25, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 10, FLAGS.dim_2)]
 
        '''

        # modified
        model = SupervisedGraphsage(
            num_classes,
            placeholders,
            features,
            adj_info,
            #loss_node,
            #loss_node_count,
            minibatch.deg,
            layer_infos,
            concat=True,
            model_size=FLAGS.model_size,
            sigmoid_loss=FLAGS.sigmoid,
            identity_dim=FLAGS.identity_dim,
            logging=True)
#
#        model = SupervisedGraphsage(num_classes, placeholders,
#                                     features,
#                                     adj_info,
#                                     minibatch.deg,
#                                     layer_infos,
#                                     model_size=FLAGS.model_size,
#                                     sigmoid_loss = FLAGS.sigmoid,
#                                     identity_dim = FLAGS.identity_dim,
#                                     logging=True)

    elif FLAGS.model == 'mean_add':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_3)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]
        '''        
        ### 3 layer test
        layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 25, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 10, FLAGS.dim_2)]
 
        '''

        # modified
        model = SupervisedGraphsage(
            num_classes,
            placeholders,
            features,
            adj_info,
            #loss_node,
            #loss_node_count,
            minibatch.deg,
            layer_infos,
            concat=False,
            model_size=FLAGS.model_size,
            sigmoid_loss=FLAGS.sigmoid,
            identity_dim=FLAGS.identity_dim,
            logging=True)
#
#        model = SupervisedGraphsage(num_classes, placeholders,
#                                     features,
#                                     adj_info,
#                                     minibatch.deg,
#                                     layer_infos,
#                                     model_size=FLAGS.model_size,
#                                     sigmoid_loss = FLAGS.sigmoid,
#                                     identity_dim = FLAGS.identity_dim,
#                                     logging=True)

    elif FLAGS.model == 'LRmean_add':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_3)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]
        '''        
        ### 3 layer test
        layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 25, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 10, FLAGS.dim_2)]
 
        '''

        # modified
        model = SupervisedGraphsage(
            num_classes,
            placeholders,
            features,
            adj_info,
            #loss_node,
            #loss_node_count,
            minibatch.deg,
            layer_infos,
            aggregator_type="LRmean",
            concat=False,
            model_size=FLAGS.model_size,
            sigmoid_loss=FLAGS.sigmoid,
            identity_dim=FLAGS.identity_dim,
            logging=True)
#
#        model = SupervisedGraphsage(num_classes, placeholders,
#                                     features,
#                                     adj_info,
#                                     minibatch.deg,
#                                     layer_infos,
#                                     model_size=FLAGS.model_size,
#                                     sigmoid_loss = FLAGS.sigmoid,
#                                     identity_dim = FLAGS.identity_dim,
#                                     logging=True)

    elif FLAGS.model == 'logicmean':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]
        '''        
        ### 3 layer test
        layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 25, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 10, FLAGS.dim_2)]
 
        '''

        # modified
        model = SupervisedGraphsage(
            num_classes,
            placeholders,
            features,
            adj_info,
            #loss_node,
            #loss_node_count,
            minibatch.deg,
            layer_infos,
            aggregator_type='logicmean',
            concat=True,
            model_size=FLAGS.model_size,
            sigmoid_loss=FLAGS.sigmoid,
            identity_dim=FLAGS.identity_dim,
            logging=True)
#
    elif FLAGS.model == 'attmean':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]
        '''        
        ### 3 layer test
        layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 25, FLAGS.dim_2),
                                SAGEInfo("node", sampler, 10, FLAGS.dim_2)]
 
        '''

        # modified
        model = SupervisedGraphsage(
            num_classes,
            placeholders,
            features,
            adj_info,
            #loss_node,
            #loss_node_count,
            minibatch.deg,
            layer_infos,
            aggregator_type='attmean',
            model_size=FLAGS.model_size,
            sigmoid_loss=FLAGS.sigmoid,
            identity_dim=FLAGS.identity_dim,
            logging=True)
#
#        model = SupervisedGraphsage(num_classes, placeholders,
#                                     features,
#                                     adj_info,
#                                     minibatch.deg,
#                                     layer_infos,
#                                     model_size=FLAGS.model_size,
#                                     sigmoid_loss = FLAGS.sigmoid,
#                                     identity_dim = FLAGS.identity_dim,
#                                     logging=True)

    elif FLAGS.model == 'gcn':

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(sampler_name), sess.graph)

    # Save model
    saver = tf.train.Saver()
    model_path = './model/' + FLAGS.train_prefix.split(
        '/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
        model=FLAGS.model, model_size=FLAGS.model_size, lr=FLAGS.learning_rate)

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Restore params of ML sampler model
    if sampler_name == 'ML' or sampler_name == 'FastML':
        sampler_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         scope="MLsampler")
        #pdb.set_trace()
        saver_sampler = tf.train.Saver(var_list=sampler_vars)
        sampler_model_path = './model/MLsampler-' + FLAGS.train_prefix.split(
            '/')[-1] + '-' + FLAGS.model_prefix
        sampler_model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)

        saver_sampler.restore(sess, sampler_model_path + 'model.ckpt')
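        # The Saver above is restricted to variables in the "MLsampler" scope, so only
        # the pretrained sampler weights are restored; the GraphSAGE weights keep the
        # fresh initialization from global_variables_initializer().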

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    val_cost_ = []
    val_f1_mic_ = []
    val_f1_mac_ = []
    duration_ = []

    ln_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.float32)
    lnc_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.int32)

    ln_acc = ln_acc.tolil()
    lnc_acc = lnc_acc.tolil()
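    # ln_acc / lnc_acc accumulate the per-node-pair loss and visit counts over training;
    # LIL format makes the repeated fancy-index updates below cheap, and both are
    # converted back to CSR only when saved at the end.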
    #
    #    ln_acc = np.zeros([adj_shape[0], adj_shape[0]])
    #    lnc_acc = np.zeros([adj_shape[0], adj_shape[0]])

    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)

        #for j in range(2):
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()

            # skip the last minibatch if it is smaller than the configured batch size
            if feed_dict[placeholders['batch_size']] != FLAGS.batch_size:
                break

            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            #outs = sess.run([merged, model.opt_op, model.loss, model.preds], feed_dict=feed_dict)

            outs = sess.run([
                merged, model.opt_op, model.loss, model.preds, model.loss_node,
                model.loss_node_count, model.out_mean
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)

                # accumulate val results
                val_cost_.append(val_cost)
                val_f1_mic_.append(val_f1_mic)
                val_f1_mac_.append(val_f1_mac)
                duration_.append(duration)

                #
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            # loss_node
            #import pdb
            #pdb.set_trace()

            #            if epoch > 0.7*FLAGS.epochs:
            #                ln = outs[-2].values
            #                ln_idx = outs[-2].indices
            #                ln_acc[ln_idx[:,0], ln_idx[:,1]] += ln
            #
            #
            #                lnc = outs[-1].values
            #                lnc_idx = outs[-1].indices
            #                lnc_acc[lnc_idx[:,0], lnc_idx[:,1]] += lnc

            ln = outs[4].values
            ln_idx = outs[4].indices
            ln_acc[ln_idx[:, 0], ln_idx[:, 1]] += ln

            lnc = outs[5].values
            lnc_idx = outs[5].indices
            lnc_acc[lnc_idx[:, 0], lnc_idx[:, 1]] += lnc

            #pdb.set_trace()
            #idx = np.where(lnc_acc != 0)
            #loss_node_mean = (ln_acc[idx[0], idx[1]]).mean()
            #loss_node_count_mean = (lnc_acc[idx[0], idx[1]]).mean()

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[3])
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_f1_mic=",
                    "{:.5f}".format(train_f1_mic),
                    "train_f1_mac=",
                    "{:.5f}".format(train_f1_mac),
                    "val_loss=",
                    "{:.5f}".format(val_cost),
                    "val_f1_mic=",
                    "{:.5f}".format(val_f1_mic),
                    "val_f1_mac=",
                    "{:.5f}".format(val_f1_mac),
                    #"loss_node=", "{:.5f}".format(loss_node_mean),
                    #"loss_node_count=", "{:.5f}".format(loss_node_count_mean),
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    # Save model
    save_path = saver.save(sess, model_path + 'model.ckpt')
    print('model is saved at %s' % save_path)

    # Save loss node and count
    loss_node_path = './loss_node/' + FLAGS.train_prefix.split(
        '/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    loss_node_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
        model=FLAGS.model, model_size=FLAGS.model_size, lr=FLAGS.learning_rate)
    if not os.path.exists(loss_node_path):
        os.makedirs(loss_node_path)

    # scipy.sparse.save_npz returns None, so the results are not kept
    sparse.save_npz(loss_node_path + 'loss_node.npz', sparse.csr_matrix(ln_acc))
    sparse.save_npz(loss_node_path + 'loss_node_count.npz',
                    sparse.csr_matrix(lnc_acc))
    print('loss and count per node are saved at %s' % loss_node_path)

    #    # save images of loss node and count
    #    plt.imsave(loss_node_path + 'loss_node_mean.png', np.uint8(np.round(np.divide(ln_acc.todense()[:1024,:1024], lnc_acc.todense()[:1024,:1024]+1e-10))), cmap='jet', vmin=0, vmax=255)
    #    plt.imsave(loss_node_path + 'loss_node_count.png', np.uint8(lnc_acc.todense()[:1024,:1024]), cmap='jet', vmin=0, vmax=255)
    #

    print("Validation per epoch in training")
    for ep in range(FLAGS.epochs):
        print("Epoch: %04d" % ep, " val_cost={:.5f}".format(val_cost_[ep]),
              " val_f1_mic={:.5f}".format(val_f1_mic_[ep]),
              " val_f1_mac={:.5f}".format(val_f1_mac_[ep]),
              " duration={:.5f}".format(duration_[ep]))

    print("Optimization Finished!")
    sess.run(val_adj_info.op)

    # full validation
    val_cost_ = []
    val_f1_mic_ = []
    val_f1_mac_ = []
    duration_ = []
    for iter in range(10):
        val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
            sess, model, minibatch, FLAGS.batch_size)
        print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
              "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
              "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))

        val_cost_.append(val_cost)
        val_f1_mic_.append(val_f1_mic)
        val_f1_mac_.append(val_f1_mac)
        duration_.append(duration)

    print("mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".
          format(np.mean(val_cost_), np.mean(val_f1_mic_),
                 np.mean(val_f1_mac_), np.mean(duration_)))

    # write validation results
    with open(log_dir(sampler_name) + "val_stats.txt", "w") as fp:
        for iter in range(10):
            fp.write(
                "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".
                format(val_cost_[iter], val_f1_mic_[iter], val_f1_mac_[iter],
                       duration_[iter]))

        fp.write(
            "mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".
            format(np.mean(val_cost_), np.mean(val_f1_mic_),
                   np.mean(val_f1_mac_), np.mean(duration_)))
        fp.write(
            "variance: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n"
            .format(np.var(val_cost_), np.var(val_f1_mic_),
                    np.var(val_f1_mac_), np.var(duration_)))

    # test
    val_cost_ = []
    val_f1_mic_ = []
    val_f1_mac_ = []
    duration_ = []

    print("Writing test set stats to file (don't peak!)")

    # timeline
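    # FULL_TRACE records per-op timing metadata so a Chrome trace can be
    # generated later with the `timeline` module (see the end of this function).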
    if FLAGS.timeline:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
    else:
        run_options = None
        run_metadata = None

    for iter in range(10):

        val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
            sess,
            model,
            minibatch,
            FLAGS.batch_size,
            run_options,
            run_metadata,
            test=True)

        #val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size, test=True)
        print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
              "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
              "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))

        val_cost_.append(val_cost)
        val_f1_mic_.append(val_f1_mic)
        val_f1_mac_.append(val_f1_mac)
        duration_.append(duration)

    print("mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".
          format(np.mean(val_cost_), np.mean(val_f1_mic_),
                 np.mean(val_f1_mac_), np.mean(duration_)))

    # write test results
    with open(log_dir(sampler_name) + "test_stats.txt", "w") as fp:
        for iter in range(10):
            fp.write(
                "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".
                format(val_cost_[iter], val_f1_mic_[iter], val_f1_mac_[iter],
                       duration_[iter]))

        fp.write(
            "mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n".
            format(np.mean(val_cost_), np.mean(val_f1_mic_),
                   np.mean(val_f1_mac_), np.mean(duration_)))
        fp.write(
            "variance: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n"
            .format(np.var(val_cost_), np.var(val_f1_mic_),
                    np.var(val_f1_mac_), np.var(duration_)))

    # create timeline object, and write it to a json
    if FLAGS.timeline:
        tl = timeline.Timeline(run_metadata.step_stats)
        ctf = tl.generate_chrome_trace_format(show_memory=True)
        with open(log_dir(sampler_name) + 'timeline.json', 'w') as f:
            print('timeline written at %s' %
                  (log_dir(sampler_name) + 'timeline.json'))
            f.write(ctf)

    sess.close()
    tf.reset_default_graph()
def train(train_data, test_data=None):
    # return G, feats, id_map, walks, class_map
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]

    # Determine the number of classes
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    # Append an all-zero row to the feature matrix; it acts as the dummy
    # (padding) feature vector.
    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features,
                              np.zeros((features.shape[1], ))])  # vertical stack

    # Random-walk co-occurrence pairs (only used when FLAGS.random_context is set)
    context_pairs = train_data[3] if FLAGS.random_context else None

    # Build the placeholders (labels, batch, dropout, batch_size): the
    # non-trainable input nodes of the graph.
    placeholders = construct_placeholders(num_classes)

    # Initialize NodeMinibatchIterator: prepares the node splits and adjacency lists
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)

    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    # Store the adjacency list in a non-trainable Variable so it can later be
    # swapped between the train and test adjacency via tf.assign.
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)  # samples neighbors uniformly at random
        if FLAGS.samples_3 != 0:
            layer_infos = [
                # FLAGS.samples_1: number of sampled neighbors per node;
                # FLAGS.dim_1 / dim_2: output dimensions of the hidden layers
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)  # write graph and summaries to the log directory

    # Init variables
    # "_ph" stands for placeholder; adj_info_ph is fed with the training
    # adjacency list here.
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    # Assign ops that swap adj_info between the train and test adjacency lists;
    # the assignment only takes effect when the corresponding op is run.
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            # feed_dict now contains: dropout, batch_size, batch (the sampled
            # node ids) and labels.
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            # Debug logging: each of these sess.run calls triggers an extra
            # forward pass just to inspect the sampled inputs.
            # print("inputs1 shape", sess.run([model.shapelog], feed_dict=feed_dict))
            print("inputs:", sess.run([model.inputs1], feed_dict=feed_dict))
            print("samples0", sess.run([model.log0], feed_dict=feed_dict))
            print("samples1", sess.run([model.log1], feed_dict=feed_dict))
            print("samples2", sess.run([model.log2], feed_dict=feed_dict))

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            # Debugging leftover: an unconditional `break` here would stop
            # training after a single minibatch, so it is disabled.
            # break

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_f1_mic=",
                      "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                      "{:.5f}".format(train_f1_mac), "val_loss=",
                      "{:.5f}".format(val_cost), "val_f1_mic=",
                      "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                      "{:.5f}".format(val_f1_mac), "time=",
                      "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
          "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write(
            "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            val_cost, val_f1_mic, val_f1_mac))
Пример #15
0
def train(train_data, test_data=None):

    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.compat.v1.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.compat.v1.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize WandB experiment
    wandb.init(project='chengdu_GraphSAGE',
               save_code=True,
               tags=['supervised'])
    wandb.config.update(flags.FLAGS)

    # Initialize session
    sess = tf.compat.v1.Session(config=config)
    merged = tf.compat.v1.summary.merge_all()
    summary_writer = tf.compat.v1.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.compat.v1.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Init saver
    saver = tf.compat.v1.train.Saver(max_to_keep=8,
                                     keep_checkpoint_every_n_hours=1)

    # Train model
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
    val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            # Validation
            if iter % FLAGS.validate_iter == 0:
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("[%03d/%03d]" % (epoch + 1, FLAGS.epochs), "Iter:",
                      '%04d' % iter, "train_loss =",
                      "{:.5f}".format(train_cost), "train_f1_mic =",
                      "{:.5f}".format(train_f1_mic), "train_f1_mac =",
                      "{:.5f}".format(train_f1_mac), "val_loss =",
                      "{:.5f}".format(val_cost), "val_f1_mic =",
                      "{:.5f}".format(val_f1_mic), "val_f1_mac =",
                      "{:.5f}".format(val_f1_mac), "time =",
                      "{:.5f}".format(avg_time))

            # W&B Logging
            if FLAGS.wandb_log and iter % FLAGS.wandb_log_iter == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                wandb.log({'train_loss': train_cost, 'epoch': epoch})
                wandb.log({'train_f1_mic': train_f1_mic, 'epoch': epoch})
                wandb.log({'train_f1_mac': train_f1_mac, 'epoch': epoch})
                wandb.log({'val_cost': val_cost, 'epoch': epoch})
                wandb.log({'val_f1_mic': val_f1_mic, 'epoch': epoch})
                wandb.log({'val_f1_mac': val_f1_mac, 'epoch': epoch})
                wandb.log({'time': avg_time, 'epoch': epoch})
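                # Note: each wandb.log call above creates its own logging step;
                # the metrics could equivalently be combined into a single
                # wandb.log({...}) call so they share one step.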

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        # Save Model checkpoints
        if FLAGS.save_checkpoints and epoch % FLAGS.save_checkpoints_epoch == 0:
            # saver.save(sess, log_dir() + 'model', global_step=1000)
            print('Save model checkpoint:', wandb.run.dir, iter, total_steps,
                  epoch)
            saver.save(
                sess,
                os.path.join(wandb.run.dir,
                             "model-" + str(epoch + 1) + ".ckpt"))

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
          "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write(
            "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            val_cost, val_f1_mic, val_f1_mac))
Пример #16
0
def train(train_data, action, test_data=None, regress_fun=run_regression):
    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    # Initialize session
    # Use a dedicated graph/scope, otherwise it would conflict with the Controller
    with tf.Session(config=config, graph=tf.Graph()) as sess:
        with sess.as_default():
            with sess.graph.as_default():
                # Set random seed
                seed = 123
                np.random.seed(seed)
                tf.set_random_seed(seed)

                G = train_data[0]
                features = train_data[1]
                id_map = train_data[2]

                if features is not None:
                    # pad with dummy zero vector
                    features = np.vstack(
                        [features, np.zeros((features.shape[1], ))])

                context_pairs = train_data[3] if FLAGS.random_context else None
                placeholders = construct_placeholders()
                minibatch = EdgeMinibatchIterator(
                    G,
                    id_map,
                    placeholders,
                    batch_size=FLAGS.batch_size,
                    max_degree=FLAGS.max_degree,
                    num_neg_samples=FLAGS.neg_sample_size,
                    context_pairs=context_pairs)
                adj_info_ph = tf.placeholder(tf.int32,
                                             shape=minibatch.adj.shape)
                adj_info = tf.Variable(adj_info_ph,
                                       trainable=False,
                                       name="adj_info")

                sampler = UniformNeighborSampler(adj_info)  # samples neighbors by randomly permuting them
                state_nums = 2  # number of states defined by the Controller
                layers_num = len(action) // state_nums  # number of layers
                layer_infos = []
                # Layer specs that guide construction of the final GNN
                layer_infos.append(
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1))
                layer_infos.append(
                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1))
                # Unsupervised GraphSAGE variant used for NAS
                model = NASUnsupervisedGraphsage(
                    placeholders,
                    features,
                    adj_info,
                    minibatch.deg,
                    layer_infos=layer_infos,
                    action=action,
                    state_nums=state_nums,
                    model_size=FLAGS.model_size,
                    identity_dim=FLAGS.identity_dim,
                    logging=True)

                merged = tf.summary.merge_all()
                summary_writer = tf.summary.FileWriter(log_dir(action),
                                                       sess.graph)

                # Init variables
                sess.run(tf.global_variables_initializer(),
                         feed_dict={adj_info_ph: minibatch.adj})

                # Train model

                train_shadow_mrr = None
                shadow_mrr = None

                total_steps = 0
                avg_time = 0.0
                epoch_val_costs = []

                train_adj_info = tf.assign(adj_info, minibatch.adj)
                val_adj_info = tf.assign(adj_info, minibatch.test_adj)
                for epoch in range(FLAGS.epochs):
                    minibatch.shuffle()

                    iter = 0
                    print('Epoch: %04d' % (epoch + 1))
                    epoch_val_costs.append(0)
                    while not minibatch.end():
                        # Construct feed dictionary
                        feed_dict = minibatch.next_minibatch_feed_dict()
                        feed_dict.update(
                            {placeholders['dropout']: FLAGS.dropout})

                        t = time.time()
                        # Training step
                        outs = sess.run([
                            merged, model.opt_op, model.loss, model.ranks,
                            model.aff_all, model.mrr, model.outputs1
                        ],
                                        feed_dict=feed_dict)
                        train_cost = outs[2]
                        train_mrr = outs[5]
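                        # Exponential moving average of the training MRR with
                        # decay 0.99 (shadow = 0.99 * shadow + 0.01 * current).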
                        if train_shadow_mrr is None:
                            train_shadow_mrr = train_mrr  #
                        else:
                            train_shadow_mrr -= (1 - 0.99) * (
                                train_shadow_mrr - train_mrr)

                        if iter % FLAGS.validate_iter == 0:
                            # Validation
                            sess.run(val_adj_info.op)
                            val_cost, ranks, val_mrr, duration = evaluate(
                                sess,
                                model,
                                minibatch,
                                size=FLAGS.validate_batch_size)
                            sess.run(train_adj_info.op)
                            epoch_val_costs[-1] += val_cost
                        if shadow_mrr is None:
                            shadow_mrr = val_mrr
                        else:
                            shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

                        if total_steps % FLAGS.print_every == 0:
                            summary_writer.add_summary(outs[0], total_steps)

                        # Print results
                        avg_time = (avg_time * total_steps + time.time() -
                                    t) / (total_steps + 1)

                        if total_steps % FLAGS.print_every == 0:
                            print(
                                "Iter:",
                                '%04d' % iter,
                                "train_loss=",
                                "{:.5f}".format(train_cost),
                                "train_mrr=",
                                "{:.5f}".format(train_mrr),
                                "train_mrr_ema=",
                                "{:.5f}".format(
                                    train_shadow_mrr
                                ),  # exponential moving average
                                "val_loss=",
                                "{:.5f}".format(val_cost),
                                "val_mrr=",
                                "{:.5f}".format(val_mrr),
                                "val_mrr_ema=",
                                "{:.5f}".format(
                                    shadow_mrr),  # exponential moving average
                                "time=",
                                "{:.5f}".format(avg_time))

                        iter += 1
                        total_steps += 1

                        if total_steps > FLAGS.max_total_steps:
                            break

                    if total_steps > FLAGS.max_total_steps:
                        break

                print("Optimization Finished!")
                sess.run(val_adj_info.op)
                # val_losess,val_mrr, duration = incremental_evaluate(sess, model, minibatch,
                #                                                  FLAGS.batch_size)
                # print("Full validation stats:",
                #       "loss=", "{:.5f}".format(val_losess),
                #       "val_mrr=", "{:.5f}".format(val_mrr),
                #       "time=", "{:.5f}".format(duration))

                if FLAGS.save_embeddings:
                    sess.run(val_adj_info.op)

                    embeds = save_val_embeddings(sess, model, minibatch,
                                                 FLAGS.validate_batch_size,
                                                 log_dir(action))

                    if FLAGS.model == "n2v":
                        # stopping the gradient for the already trained nodes
                        train_ids = tf.constant(
                            [[id_map[n]] for n in G.nodes_iter()
                             if not G.node[n]['val'] and not G.node[n]['test']
                             ],
                            dtype=tf.int32)
                        test_ids = tf.constant(
                            [[id_map[n]] for n in G.nodes_iter()
                             if G.node[n]['val'] or G.node[n]['test']],
                            dtype=tf.int32)
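                        # Rebuild context_embeds so that only the val/test rows
                        # receive gradient updates: both partitions are scattered
                        # back to the full embedding shape, and the rows of the
                        # already-trained (train-split) nodes are wrapped in
                        # stop_gradient.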
                        update_nodes = tf.nn.embedding_lookup(
                            model.context_embeds, tf.squeeze(test_ids))
                        no_update_nodes = tf.nn.embedding_lookup(
                            model.context_embeds, tf.squeeze(train_ids))
                        update_nodes = tf.scatter_nd(
                            test_ids, update_nodes,
                            tf.shape(model.context_embeds))
                        no_update_nodes = tf.stop_gradient(
                            tf.scatter_nd(train_ids, no_update_nodes,
                                          tf.shape(model.context_embeds)))
                        model.context_embeds = update_nodes + no_update_nodes
                        sess.run(model.context_embeds)

                        # run random walks
                        from graphsage.utils import run_random_walks
                        nodes = [
                            n for n in G.nodes_iter()
                            if G.node[n]["val"] or G.node[n]["test"]
                        ]
                        start_time = time.time()
                        pairs = run_random_walks(G, nodes, num_walks=50)
                        walk_time = time.time() - start_time

                        test_minibatch = EdgeMinibatchIterator(
                            G,
                            id_map,
                            placeholders,
                            batch_size=FLAGS.batch_size,
                            max_degree=FLAGS.max_degree,
                            num_neg_samples=FLAGS.neg_sample_size,
                            context_pairs=pairs,
                            n2v_retrain=True,
                            fixed_n2v=True)

                        start_time = time.time()
                        print("Doing test training for n2v.")
                        test_steps = 0
                        for epoch in range(FLAGS.n2v_test_epochs):
                            test_minibatch.shuffle()
                            while not test_minibatch.end():
                                feed_dict = test_minibatch.next_minibatch_feed_dict(
                                )
                                feed_dict.update(
                                    {placeholders['dropout']: FLAGS.dropout})
                                outs = sess.run([
                                    model.opt_op, model.loss, model.ranks,
                                    model.aff_all, model.mrr, model.outputs1
                                ],
                                                feed_dict=feed_dict)
                                if test_steps % FLAGS.print_every == 0:
                                    print("Iter:", '%04d' % test_steps,
                                          "train_loss=",
                                          "{:.5f}".format(outs[1]),
                                          "train_mrr=",
                                          "{:.5f}".format(outs[-2]))
                                test_steps += 1
                        train_time = time.time() - start_time
                        save_val_embeddings(sess,
                                            model,
                                            minibatch,
                                            FLAGS.validate_batch_size,
                                            log_dir(action),
                                            mod="-test")
                        print("Total time: ", train_time + walk_time)
                        print("Walk time: ", walk_time)
                        print("Train time: ", train_time)
    setting = "val"
    tf.reset_default_graph()
    labels = train_data[4]
    train_ids = [
        n for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']
    ]
    test_ids = [n for n in G.nodes() if G.node[n][setting]]
    train_labels = np.array([labels[i] for i in train_ids])
    if train_labels.ndim == 1:
        train_labels = np.expand_dims(train_labels, 1)
    test_labels = np.array([labels[i] for i in test_ids])
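    # val.txt (presumably written by save_val_embeddings above) lists node ids
    # in the order their embeddings were saved; map each id string to its row
    # index in `embeds`.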
    id_map = {}
    with open(log_dir(action) + "/val.txt") as fp:
        for i, line in enumerate(fp):
            id_map[line.strip()] = i
    train_embeds = embeds[[id_map[str(id)] for id in train_ids]]
    test_embeds = embeds[[id_map[str(id)] for id in test_ids]]

    print("Running regression..")
    regress_fun(train_embeds, train_labels, test_embeds, test_labels)
    # The reward replaces accuracy with the F1 score; no exponential moving
    # average is applied here.
    return get_rewards(val_mrr), val_mrr
Пример #17
0
def train(train_data, test_data=None):

    G = train_data[0]
    features = train_data[1]

    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    placeholders = construct_placeholders()
    minibatch = NodeMinibatchIterator(G,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)

        # Construct feed dictionary
        feed_dict = minibatch.next_minibatch_feed_dict()
        feed_dict.update({placeholders['dropout']: FLAGS.dropout})

        t = time.time()
        # Forward pass only: model.node_preds is fetched without opt_op, so no
        # parameters are updated here.
        outs = sess.run([model.node_preds], feed_dict=feed_dict)
        print(outs[0].shape)
        #if total_steps % FLAGS.print_every == 0:
        #   summary_writer.add_summary(outs[0], total_steps)

        # Print results
        avg_time = (avg_time * total_steps + time.time() - t) / (total_steps +
                                                                 1)

        iter += 1
        total_steps += 1

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
Пример #18
0
def train(train_data, action, test_data=None):
    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    # Initialize session
    # Use a dedicated graph/scope, otherwise it would conflict with the Controller
    with tf.Session(config=config, graph=tf.Graph()) as sess:
        with sess.as_default():
            with sess.graph.as_default():
                # Set random seed
                seed = 123
                np.random.seed(seed)
                tf.set_random_seed(seed)

                G = train_data[0]  # graph data
                features = train_data[1]  # node feature values
                id_map = train_data[2]  # mapping from node id to index
                class_map = train_data[4]  # node classes
                # Count the number of classes
                if isinstance(list(class_map.values())[0], list):
                    num_classes = len(list(class_map.values())[0])
                else:
                    num_classes = len(set(class_map.values()))
                # Append an all-zero feature row; it acts as the dummy (padding)
                # vector.
                if features is not None:
                    # pad with dummy zero vector
                    features = np.vstack(
                        [features, np.zeros((features.shape[1], ))])
                # Co-occurrence pairs generated by random walks, used in place of
                # the graph's edge information
                context_pairs = train_data[3] if FLAGS.random_context else None
                placeholders = construct_placeholders(num_classes)
                minibatch = NodeMinibatchIterator(
                    G,
                    id_map,
                    placeholders,
                    class_map,
                    num_classes,
                    batch_size=FLAGS.batch_size,
                    max_degree=FLAGS.max_degree,
                    context_pairs=context_pairs)  # handles minibatching
                adj_info_ph = tf.placeholder(tf.int32,
                                             shape=minibatch.adj.shape)
                adj_info = tf.Variable(adj_info_ph,
                                       trainable=False,
                                       name="adj_info")
                # Build the model
                sampler = UniformNeighborSampler(adj_info)  # samples neighbors by randomly permuting them
                state_nums = 2  # number of states defined by the Controller
                layers_num = len(action) // state_nums  # number of layers
                layer_infos = []
                # Layer specs that guide construction of the final GNN; only the
                # number of sampled neighbors is varied here
                # for i in range(layers_num):
                layer_infos.append(
                    SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1))
                layer_infos.append(
                    SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1))
                # Supervised GraphSAGE variant used for NAS
                model = NASSupervisedGraphsage(num_classes,
                                               placeholders,
                                               features,
                                               adj_info,
                                               minibatch.deg,
                                               layer_infos,
                                               state_nums=state_nums,
                                               action=action,
                                               model_size=FLAGS.model_size,
                                               sigmoid_loss=FLAGS.sigmoid,
                                               identity_dim=FLAGS.identity_dim,
                                               logging=True)

                # Summaries / logging
                merged = tf.summary.merge_all()
                summary_writer = tf.summary.FileWriter(log_dir(action),
                                                       sess.graph)

                # Init variables
                sess.run(tf.global_variables_initializer(),
                         feed_dict={adj_info_ph: minibatch.adj})

                # Train model

                total_steps = 0
                avg_time = 0.0
                epoch_val_costs = []

                train_adj_info = tf.assign(adj_info, minibatch.adj)
                val_adj_info = tf.assign(adj_info, minibatch.test_adj)
                for epoch in range(FLAGS.epochs):
                    minibatch.shuffle()

                    iter = 0
                    print('Epoch: %04d' % (epoch + 1))
                    epoch_val_costs.append(0)
                    while not minibatch.end():
                        # Construct feed dictionary
                        feed_dict, labels = minibatch.next_minibatch_feed_dict(
                        )
                        feed_dict.update(
                            {placeholders['dropout']: FLAGS.dropout})

                        t = time.time()
                        # Training step
                        outs = sess.run(
                            [merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
                        train_cost = outs[2]

                        if iter % FLAGS.validate_iter == 0:
                            # Validation
                            sess.run(val_adj_info.op)
                            if FLAGS.validate_batch_size == -1:
                                val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                                    sess, model, minibatch, FLAGS.batch_size)
                            else:
                                val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                                    sess, model, minibatch,
                                    FLAGS.validate_batch_size)
                            sess.run(train_adj_info.op)
                            epoch_val_costs[-1] += val_cost

                        if total_steps % FLAGS.print_every == 0:
                            summary_writer.add_summary(outs[0], total_steps)

                        # Print results
                        avg_time = (avg_time * total_steps + time.time() -
                                    t) / (total_steps + 1)

                        if total_steps % FLAGS.print_every == 0:
                            train_f1_mic, train_f1_mac = calc_f1(
                                labels, outs[-1])
                            print("Iter:", '%04d' % iter, "train_loss=",
                                  "{:.5f}".format(train_cost), "train_f1_mic=",
                                  "{:.5f}".format(train_f1_mic),
                                  "train_f1_mac=",
                                  "{:.5f}".format(train_f1_mac), "val_loss=",
                                  "{:.5f}".format(val_cost), "val_f1_mic=",
                                  "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                                  "{:.5f}".format(val_f1_mac), "time=",
                                  "{:.5f}".format(avg_time))

                        iter += 1
                        total_steps += 1

                        if total_steps > FLAGS.max_total_steps:
                            break

                    if total_steps > FLAGS.max_total_steps:
                        break

                print("Optimization Finished!")
                sess.run(val_adj_info.op)
                # Incremental (batched) full validation
                val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                    sess, model, minibatch, FLAGS.batch_size)
                print("Full validation stats:", "loss=",
                      "{:.5f}".format(val_cost), "f1_micro=",
                      "{:.5f}".format(val_f1_mic), "f1_macro=",
                      "{:.5f}".format(val_f1_mac), "time=",
                      "{:.5f}".format(duration))
                with open(log_dir(action) + "val_stats.txt", "w") as fp:
                    fp.write(
                        "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}"
                        .format(val_cost, val_f1_mic, val_f1_mac, duration))

                print("Writing test set stats to file (don't peak!)")
                val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                    sess, model, minibatch, FLAGS.batch_size, test=True)
                with open(log_dir(action) + "test_stats.txt", "w") as fp:
                    fp.write(
                        "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
                            val_cost, val_f1_mic, val_f1_mac))
    tf.reset_default_graph()
    # use the F1 score in place of accuracy; no exponential moving average is applied here
    return get_rewards(val_f1_mic), val_f1_mic
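
The comment above notes that the raw validation micro-F1 is returned as the reward with no smoothing. If a smoothed reward were wanted, a minimal sketch along these lines would work; the module-level _ema_f1 state and the 0.99 decay are illustrative choices, not part of this example:

_ema_f1 = None

def smoothed_reward(val_f1_mic, decay=0.99):
    # Exponentially smooth the validation micro-F1 before converting it to a reward.
    global _ema_f1
    _ema_f1 = val_f1_mic if _ema_f1 is None else decay * _ema_f1 + (1 - decay) * val_f1_mic
    return get_rewards(_ema_f1), _ema_f1
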
Пример #19
0
def train(train_data, test_data=None):
    G = train_data[0]
    features_np = train_data[1]
    id_map = train_data[2]
    train_nodes = train_data[3]
    class_map = train_data[4]

    num_classes = class_map.shape[1]

    if not features_np is None:
        # pad with dummy zero vector
        features_np = np.vstack([features_np, np.zeros((features_np.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      train_nodes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    features_ph = tf.placeholder(tf.float32, shape=features_np.shape)
    features = tf.Variable(features_ph, trainable=False, name="features")
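    # These non-trainable Variables are initialized from placeholders: the large
    # adjacency and feature arrays are fed once when global_variables_initializer
    # runs (see the feed_dict below) and refreshed later via the up_adj_info /
    # up_features assign ops.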

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                           SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                           SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)]
        elif FLAGS.samples_2 != 0:
            layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                           SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
        else:
            layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)]

        model = SupervisedGraphsage(num_classes, placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SupervisedGraphsage(num_classes, placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj,
                                                           features_ph: features_np})

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []
    best = 0

    up_adj_info = tf.assign(adj_info, adj_info_ph, name='up_adj')
    up_features = tf.assign(features, features_ph, name='up_features')
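    # Running these assign ops with a feed_dict swaps adj_info between minibatch.adj
    # and minibatch.test_adj around the validation calls below (and refreshes the
    # feature table) without rebuilding the graph.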

    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds], feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(up_adj_info.op, feed_dict={adj_info_ph: minibatch.test_adj})
                # sess.run([adj_info], feed_dict={adj_info_ph: minibatch.test_adj})

                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch,
                                                                                      FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(sess, model, minibatch,
                                                                          FLAGS.validate_batch_size)

                if val_f1_mic > best:
                    print("Saving best model")
                    shutil.rmtree(log_dir() + 'saved_model_best', ignore_errors=True)

                    tf.saved_model.simple_save(
                        sess, log_dir() + 'saved_model_best',
                        {'nodes': placeholders['batch'],
                         'batch_size': placeholders['batch_size'],
                         # 'adjacency': adj_info_ph,
                         # 'features': features_ph
                         },
                        {'embeddings': model.outputs1}
                    )
                    best = val_f1_mic

                # sess.run([adj_info], feed_dict={adj_info_ph: minibatch.adj})
                sess.run(up_adj_info.op, feed_dict={adj_info_ph: minibatch.adj})
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter,
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_f1_mic=", "{:.5f}".format(train_f1_mic),
                      "train_f1_mac=", "{:.5f}".format(train_f1_mac),
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_f1_mic=", "{:.5f}".format(val_f1_mic),
                      "val_f1_mac=", "{:.5f}".format(val_f1_mac),
                      "time=", "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished! Best save model", best)

    sess.run(up_adj_info.op, feed_dict={adj_info_ph: minibatch.test_adj})
    sess.run(up_features.op, feed_dict={features_ph: features_np})

    tf.saved_model.simple_save(
        sess, log_dir() + '/saved_model',
        {'nodes': placeholders['batch'],
         'batch_size': placeholders['batch_size'],
         # 'adjacency': adj_info_ph,
         # 'features': features_ph
         },
        {'embeddings': model.outputs1}
    )
    # sess.run(val_adj_info.op)

    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:",
          "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic),
          "f1_macro=", "{:.5f}".format(val_f1_mac),
          "time=", "{:.5f}".format(duration))
    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".
                 format(val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size,
                                                                      test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} best={:.5f}".
                 format(val_cost, val_f1_mic, val_f1_mac, best))
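
This example exports the trained encoder with tf.saved_model.simple_save, keyed by the 'nodes' and 'batch_size' inputs and the 'embeddings' output. A minimal sketch of loading the best checkpoint back for inference is shown below; the export path and node ids are illustrative, and it assumes that any placeholders not listed in the signature (e.g. dropout) have default values:

import numpy as np
import tensorflow as tf

export_dir = log_dir() + 'saved_model_best'      # directory written during training above
node_ids = np.array([0, 1, 2], dtype=np.int32)   # hypothetical internal node ids

with tf.Session(graph=tf.Graph()) as sess:
    meta = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    sig = meta.signature_def[
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    nodes_t = sess.graph.get_tensor_by_name(sig.inputs['nodes'].name)
    size_t = sess.graph.get_tensor_by_name(sig.inputs['batch_size'].name)
    emb_t = sess.graph.get_tensor_by_name(sig.outputs['embeddings'].name)
    embeddings = sess.run(emb_t,
                          feed_dict={nodes_t: node_ids, size_t: len(node_ids)})
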
Пример #20
0
def train(train_data, test_data, bs=None):
    G = train_data
    G_test = test_data

    if not G.features is None:
        # pad with dummy zero vector
        G.features = np.vstack([G.features, np.zeros((G.features.shape[1], ))])
    features = G.features
    # context_pairs = train_data[3] if FLAGS.random_context else None

    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(
        G,
        G_test,
        bs,
        # id_map,
        placeholders,
        batch_size=FLAGS.batch_size,
        max_degree=FLAGS.max_degree,
        num_neg_samples=FLAGS.neg_sample_size,
        # context_pairs = context_pairs
    )
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        # minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  # initialize the EMA with the first value
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss=",
                    "{:.5f}".format(val_cost),
                    "val_mrr=",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema=",
                    "{:.5f}".format(shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
Пример #21
0
def train(train_data, test_data=None):

    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    global placeholders
    placeholders = construct_placeholders(num_classes)

    # contruct both supervised and unsupervised minibatch iterators
    minibatch = SupervisedEdgeMinibatchIterator(
        G,
        id_map,
        placeholders,
        class_map,
        num_classes,
        batch_size=FLAGS.batch_size,
        max_degree=FLAGS.max_degree,
        context_pairs=context_pairs,
        complete_validation=FLAGS.complete_val)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
    label_adj_info_ph = tf.placeholder(tf.int32,
                                       shape=minibatch.label_adj.shape)
    label_adj_info = tf.Variable(label_adj_info_ph,
                                 trainable=False,
                                 name="label_adj_info")

    # Neighbors sampler
    if FLAGS.sampler == 'uniform':
        sampler = UniformNeighborSampler(adj_info)
    elif FLAGS.sampler == 'label_assisted':
        sampler = LabelAssistedNeighborSampler(adj_info, label_adj_info,
                                               FLAGS.topology_label_ratio)
    else:
        raise Exception('Error: sampler name unrecognized.')

    if FLAGS.model == 'graphsage_mean':
        # Layers definitions
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]
        # Create model
        model = SemiSupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos,
                                        aggregator_type="mean",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True)
    elif FLAGS.model == 'gcn':
        # Layers definitions
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]
        # Create model
        model = SemiSupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos,
                                        aggregator_type="gcn",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        concat=False,
                                        logging=True)
    elif FLAGS.model == 'graphsage_seq':
        # Layers definitions
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]
        # Create model
        model = SemiSupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos,
                                        aggregator_type="seq",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True)
    elif FLAGS.model == 'graphsage_maxpool':
        # Layers definitions
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]
        # Create model
        model = SemiSupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos,
                                        aggregator_type="maxpool",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        # Layers definitions
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]
        # Create model
        model = SemiSupervisedGraphsage(num_classes,
                                        placeholders,
                                        features,
                                        adj_info,
                                        minibatch.deg,
                                        layer_infos,
                                        aggregator_type="meanpool",
                                        model_size=FLAGS.model_size,
                                        sigmoid_loss=FLAGS.sigmoid,
                                        identity_dim=FLAGS.identity_dim,
                                        logging=True)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    log_dir = get_log_dir()
    sess = tf.Session(config=config)

    val_loss_sup = tf.Variable(0., trainable=False, name="val_loss_sup")
    val_loss_unsup = tf.Variable(0., trainable=False, name="val_loss_unsup")
    val_mrr_var = tf.Variable(0., trainable=False, name="val_mrr")
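    # Scratch variables: validation losses and MRR are computed in Python by
    # evaluate/incremental_evaluate and assigned into these variables so they can be
    # exported through the "val" summary scalars defined below.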
    with tf.name_scope("train"):
        summary_train_loss_sup = tf.summary.scalar('supervised loss',
                                                   model.loss_sup)
        summary_train_loss_unsup = tf.summary.scalar('unsupervised loss',
                                                     model.loss_unsup)
        summary_train_mrr = tf.summary.scalar('mrr', model.mrr)
        summary_train_acc = tf.summary.scalar('accuracy', model.accuracy)
        summary_train_f1 = tf.summary.scalar('f1 score', model.f1)
        summary_train_confusion = tf.summary.image(
            'confusion',
            tf.reshape(tf.cast(model.confusion_read, tf.float32),
                       [1, num_classes, num_classes, 1]))
        summary_train_sup = tf.summary.merge([
            summary_train_loss_sup, summary_train_mrr, summary_train_acc,
            summary_train_f1, summary_train_confusion
        ])
        summary_train_unsup = tf.summary.merge(
            [summary_train_loss_unsup, summary_train_mrr])
    with tf.name_scope("val"):
        summary_val_loss_sup = tf.summary.scalar('supervised loss',
                                                 val_loss_sup)
        summary_val_loss_unsup = tf.summary.scalar('unsupervised loss',
                                                   val_loss_unsup)
        summary_val_mrr = tf.summary.scalar('mrr', val_mrr_var)
        summary_val_acc = tf.summary.scalar(
            'accuracy', model.accuracy_read_val
        )  # only read the already computed validation accuracy
        summary_val_f1 = tf.summary.scalar(
            'f1 score', model.f1_read_val
        )  # only read the already computed validation f1 score
        summary_val_confusion = tf.summary.image(
            'confusion',
            tf.reshape(tf.cast(model.confusion_read_val, tf.float32),
                       [1, num_classes, num_classes, 1]))
        summary_val = tf.summary.merge([
            summary_val_loss_sup, summary_val_loss_unsup, summary_val_mrr,
            summary_val_acc, summary_val_f1, summary_val_confusion
        ])
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={
                 adj_info_ph: minibatch.adj,
                 label_adj_info_ph: minibatch.label_adj
             })

    # Train model

    total_steps = 0
    avg_time = 0.0

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    train_label_adj_info = tf.assign(label_adj_info, minibatch.label_adj)
    val_label_adj_info = tf.assign(label_adj_info, minibatch.test_label_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()  # shuffle the minibatches

        # init local variables
        sess.run(tf.initializers.variables(tf.local_variables(scope="train")))
        sess.run(tf.initializers.variables(tf.local_variables(scope="val")))

        iter = 0
        print('Epoch: %04d' % (epoch + 1))

        while not (minibatch.end() and
                   (minibatch.end_sup() or FLAGS.supervised_ratio == 0)):
            # define supervised or unsupervised training
            supervised = False
            if (minibatch.end()
                    or (not minibatch.end_sup() and
                        (np.random.rand() < FLAGS.supervised_ratio))):
                supervised = True
            if supervised:
                loss = model.loss_sup
                optimizer = model.sup_opt_op
            else:
                loss = model.loss_unsup
                optimizer = model.unsup_opt_op
            # Construct feed dictionary
            feed_dict, labels = (minibatch.next_minibatch_feed_dict_sup()
                                 if supervised else
                                 minibatch.next_minibatch_feed_dict())
            feed_dict.update({placeholders['dropout']:
                              FLAGS.dropout})  # TEMP: change placeholder
            feed_dict.update({placeholders['pos_class']: FLAGS.pos_class})

            # Training step
            t = time.time()
            if supervised:
                summary, _, train_cost, train_mrr, preds, confusion = sess.run(
                    [
                        summary_train_sup,
                        optimizer,  # optimization operation
                        loss,  # compute current loss
                        model.mrr,
                        model.preds,  # compute predictions for inputs
                        model.confusion
                    ],
                    feed_dict=feed_dict)
            else:
                summary, _, train_cost, train_mrr, preds, confusion = sess.run(
                    [
                        summary_train_unsup,
                        optimizer,  # optimization operation
                        loss,  # compute current loss
                        model.mrr,
                        model.preds,  # compute predictions for inputs
                        model.confusion_read
                    ],  # only read confusion matrix for unsupervised iterations
                    feed_dict=feed_dict)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run([val_adj_info.op, val_label_adj_info.op])
                if FLAGS.validate_batch_size == -1:
                    val_cost_sup, val_cost_unsup, val_mrr, val_acc, val_f1, val_confusion, duration = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost_sup, val_cost_unsup, val_mrr, val_acc, val_f1, val_confusion, duration = evaluate(
                        sess,
                        model,
                        minibatch,
                        FLAGS.validate_batch_size,
                        supervised=supervised)

                # log validation summary
                if val_cost_sup is not None:
                    sess.run(val_loss_sup.assign(val_cost_sup))
                if val_cost_unsup is not None:
                    sess.run(val_loss_unsup.assign(val_cost_unsup))
                if val_cost_unsup is not None:
                    sess.run(val_mrr_var.assign(val_mrr))
                summary_val_out = sess.run(summary_val, feed_dict=feed_dict)
                summary_writer.add_summary(summary_val_out, total_steps)

                # print validation stats
                print_iter(type="VAL",
                           epoch=epoch + 1,
                           iter=iter,
                           total_steps=total_steps,
                           loss_sup=val_cost_sup,
                           loss_unsup=val_cost_unsup,
                           mrr=val_mrr,
                           f1=val_f1,
                           accuracy=val_acc,
                           confusion=val_confusion)

                sess.run([train_adj_info.op, train_label_adj_info.op])

            # log train summary
            summary_writer.add_summary(summary, total_steps)

            # running average for training time
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            # Print training iteration results
            if total_steps % FLAGS.print_every == 0:
                print_iter(type=("SUP" if supervised else "UNS"),
                           epoch=epoch + 1,
                           iter=iter,
                           total_steps=total_steps,
                           loss_sup=(train_cost if supervised else None),
                           loss_unsup=(None if supervised else train_cost),
                           mrr=train_mrr,
                           f1=(sess.run(model.f1_read, feed_dict=feed_dict)
                               if supervised else None),
                           accuracy=(sess.run(model.accuracy_read,
                                              feed_dict=feed_dict)
                                     if supervised else None),
                           confusion=(confusion if supervised else None))

            # update counters
            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")

    # compute final validation results
    sess.run([val_adj_info.op, val_label_adj_info.op])
    val_cost_sup, val_cost_unsup, val_mrr, val_acc, val_f1, val_confusion, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)

    # log final results
    if val_cost_sup is not None:
        sess.run(val_loss_sup.assign(val_cost_sup))
    if val_cost_unsup is not None:
        sess.run(val_loss_unsup.assign(val_cost_unsup))
    if val_cost_unsup is not None:
        sess.run(val_mrr_var.assign(val_mrr))
    summary_val_out = sess.run(summary_val, feed_dict=feed_dict)
    summary_writer.add_summary(summary_val_out, total_steps)

    # print final results
    print("Full validation stats:\n", "\tsupervised loss=",
          "{:.5f}".format(val_cost_sup), "\n", "\tunsupervised loss=",
          "{:.5f}".format(val_cost_unsup), "\n", "\tmrr=",
          "{:.5f}".format(val_mrr), "\n", "\taccuracy=",
          "{:.5f}".format(val_acc), "\n", "\tf1-score=",
          "{:.5f}".format(val_f1), "\n", "\tevaluation time=",
          "{:.5f}".format(duration), "\n", "\taverage_training_time=",
          "{:.5f}".format(avg_time))
    if FLAGS.print_confusion:
        print("confusion= \n{:s}".format(val_confusion))
    # write an output file
    with open(log_dir + "val_stats.txt", "w") as fp:
        fp.write(
            "supervised_loss={:.5f}, unsupervised_loss={:.5f}, mrr={:.5f}, accuracy={:.5f}, f1-score={:.5f}, evaluation_time={:.5f}, iteration_training_time={:.5f}"
            .format(val_cost_sup, val_cost_unsup, val_mrr, val_acc, val_f1,
                    duration, avg_time))

    with open(log_dir + "command.txt", "w") as fp:
        fp.write(str(FLAGS.flag_values_dict()))

    # TODO: Perform evaluation on test set

    if FLAGS.save_embeddings:
        print("Saving embeddings..")
        sess.run([val_adj_info.op, val_label_adj_info.op])
        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir)
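
The training loop in this example interleaves supervised and unsupervised steps: a step is supervised when the unsupervised edge iterator is exhausted, or with probability FLAGS.supervised_ratio while labeled batches remain. A minimal restatement of that decision rule in isolation (function and argument names are illustrative):

import numpy as np

def is_supervised_step(unsup_done, sup_done, supervised_ratio, rng=np.random):
    # Mirrors the branch at the top of the while-loop above.
    if unsup_done:
        return True
    if sup_done:
        return False
    return rng.rand() < supervised_ratio
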
Пример #22
0
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.compat.v1.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.compat.v1.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize WandB experiment
    wandb.init(project='GraphSAGE_trial',
               entity='cvl-liu-01',
               save_code=True,
               tags=['unsupervised'])
    wandb.config.update(flags.FLAGS)

    # Initialize session
    sess = tf.compat.v1.Session(config=config)
    merged = tf.compat.v1.summary.merge_all()
    summary_writer = tf.compat.v1.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.compat.v1.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Init saver
    saver = tf.compat.v1.train.Saver(max_to_keep=8,
                                     keep_checkpoint_every_n_hours=1)

    # Train model
    train_shadow_mrr = None
    val_shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
    val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  # initialize the EMA with the first value
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            # Validation
            if iter % FLAGS.validate_iter == 0:
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(
                    sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if val_shadow_mrr is None:
                val_shadow_mrr = val_mrr
            else:
                val_shadow_mrr -= (1 - 0.99) * (val_shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "[%03d/%03d]" % (epoch + 1, FLAGS.epochs),
                    "Iter:",
                    '[%05d/%05d]' % (iter, minibatch.num_training_batches()),
                    "train_loss =",
                    "{:.5f}".format(train_cost),
                    "train_mrr =",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema =",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "val_loss =",
                    "{:.5f}".format(val_cost),
                    "val_mrr =",
                    "{:.5f}".format(val_mrr),
                    "val_mrr_ema =",
                    "{:.5f}".format(
                        val_shadow_mrr),  # exponential moving average
                    "time =",
                    "{:.5f}".format(avg_time))

            # W&B Logging
            if FLAGS.wandb_log and iter % FLAGS.wandb_log_iter == 0:
                wandb.log({'train_loss': train_cost, 'epoch': epoch})
                wandb.log({'train_mrr': train_mrr, 'epoch': epoch})
                wandb.log({'train_mrr_ema': train_shadow_mrr, 'epoch': epoch})
                wandb.log({'val_loss': val_cost, 'epoch': epoch})
                wandb.log({'val_mrr': val_mrr, 'epoch': epoch})
                wandb.log({'val_mrr_ema': val_shadow_mrr, 'epoch': epoch})
                wandb.log({'time': avg_time, 'epoch': epoch})

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                print('Max total steps reached!')
                break

        # Save embeddings
        if FLAGS.save_embeddings and epoch % FLAGS.save_embeddings_epoch == 0:
            save_val_embeddings(sess, model, minibatch,
                                FLAGS.validate_batch_size, log_dir())

            # Also report classifier metric on the embedding
            all_tr_res, all_ts_res = osm_classif_eval.evaluate(
                FLAGS.train_prefix, log_dir(), n_iter=FLAGS.classif_n_iter)
            if FLAGS.wandb_log:
                wandb.log(all_tr_res)
                wandb.log(all_ts_res)

        # Save Model checkpoints
        if FLAGS.save_checkpoints and epoch % FLAGS.save_checkpoints_epoch == 0:
            # saver.save(sess, log_dir() + 'model', global_step=1000)
            print('Save model checkpoint:', wandb.run.dir, iter, total_steps,
                  epoch)
            saver.save(
                sess,
                os.path.join(wandb.run.dir,
                             "model-" + str(epoch + 1) + ".ckpt"))

        if total_steps > FLAGS.max_total_steps:
            print('Max total steps reached!')
            break

    print("Optimization Finished!")
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Пример #23
0
def train(train_data, test_data=None, sampler_name='Uniform'):
    
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
    adj_shape = adj_info.get_shape().as_list()
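    # The Uniform / ML / FastML selection below is repeated verbatim in every model
    # branch; a small helper along these lines (hypothetical, not part of this example)
    # would remove the duplication:
    #
    #     def make_sampler(name, adj_info, features):
    #         if name == 'Uniform':
    #             return UniformNeighborSampler(adj_info)
    #         if name == 'ML':
    #             return MLNeighborSampler(adj_info, features)
    #         if name == 'FastML':
    #             return FastMLNeighborSampler(adj_info, features)
    #         raise ValueError('unrecognized sampler: %s' % name)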

    if FLAGS.model == 'mean_concat':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     concat=True,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    
    elif FLAGS.model == 'mean_add':
        # Create model

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     concat=False,
                                     layer_infos=layer_infos, 
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'gcn':

        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)

           
        # Create model
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="gcn",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     concat=False,
                                     logging=True)

    elif FLAGS.model == 'graphsage_seq':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)


        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                     features,
                                     adj_info,
                                     minibatch.deg,
                                     layer_infos=layer_infos, 
                                     identity_dim = FLAGS.identity_dim,
                                     aggregator_type="seq",
                                     model_size=FLAGS.model_size,
                                     logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)
    
            
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="maxpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        
        if sampler_name == 'Uniform':
            sampler = UniformNeighborSampler(adj_info)
        elif sampler_name == 'ML':
            sampler = MLNeighborSampler(adj_info, features)
        elif sampler_name == 'FastML':
            sampler = FastMLNeighborSampler(adj_info, features)
            
        #sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

        model = SampleAndAggregate(placeholders, 
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                     layer_infos=layer_infos, 
                                     aggregator_type="meanpool",
                                     model_size=FLAGS.model_size,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(placeholders, features.shape[0],
                                       minibatch.deg,
                                       #2x because graphsage uses concat
                                       nodevec_dim=2*FLAGS.dim_1,
                                       lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    
    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(sampler_name), sess.graph)
    
    
    # Save model
    saver = tf.train.Saver()
    model_path =  './model/unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)

    if not os.path.exists(model_path):
        os.makedirs(model_path)

    
    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
    
    
    # Restore params of ML sampler model
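    # Only the variables under the "MLsampler" scope are restored here, so the
    # GraphSAGE weights keep their fresh initialization while the sampler loads
    # parameters from a separately trained checkpoint (path built below).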
    if sampler_name == 'ML' or sampler_name == 'FastML':
        sampler_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="MLsampler")
        #pdb.set_trace() 
        saver_sampler = tf.train.Saver(var_list=sampler_vars)
        sampler_model_path = './model/MLsampler-unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix
        sampler_model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)

        saver_sampler.restore(sess, sampler_model_path + 'model.ckpt')

   
    # Train model
    
    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
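    # adj_info is a non-trainable variable holding the neighbor table; these
    # assign ops let the loop swap in the training adjacency for optimization
    # steps and the test adjacency for validation and embedding export.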

    val_cost_ = []
    val_mrr_ = []
    shadow_mrr_ = []
    duration_ = []
    
    ln_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.float32)
    lnc_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.int32)
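    # ln_acc / lnc_acc accumulate the per-node loss values and counts emitted by
    # the model (outs[7] / outs[8] in the training loop below); LIL format gives
    # cheap element-wise updates before converting back to CSR for saving.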
    
    ln_acc = ln_acc.tolil()
    lnc_acc = lnc_acc.tolil()

    for epoch in range(FLAGS.epochs): 
        minibatch.shuffle() 

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()

            # Skip the final, smaller-than-batch_size minibatch; this assumes the
            # first entry of the feed dict is the batch size, and Python 3 needs
            # the dict view materialized before indexing.
            if list(feed_dict.values())[0] != FLAGS.batch_size:
                break

            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.ranks, model.aff_all, 
                    model.mrr, model.outputs1, model.loss_node, model.loss_node_count], feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
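            # Exponential moving average of the training MRR with decay 0.99:
            #   ema <- ema - (1 - 0.99) * (ema - mrr)  ==  0.99 * ema + 0.01 * mrr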
            
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  # initialize the EMA with the first observation
            else:
                train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration  = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)
            
            val_cost_.append(val_cost)
            val_mrr_.append(val_mrr)
            shadow_mrr_.append(shadow_mrr)
            duration_.append(duration)


            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)
    
            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print("Iter:", '%04d' % iter, 
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_mrr=", "{:.5f}".format(train_mrr), 
                      "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr), # exponential moving average
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_mrr=", "{:.5f}".format(val_mrr), 
                      "val_mrr_ema=", "{:.5f}".format(shadow_mrr), # exponential moving average
                      "time=", "{:.5f}".format(avg_time))

            
            ln = outs[7].values
            ln_idx = outs[7].indices
            ln_acc[ln_idx[:,0], ln_idx[:,1]] += ln

            lnc = outs[8].values
            lnc_idx = outs[8].indices
            lnc_acc[lnc_idx[:,0], lnc_idx[:,1]] += lnc
                
                
            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break


    print("Validation per epoch in training")
    for ep in range(FLAGS.epochs):
        print("Epoch: %04d"%ep, " val_cost={:.5f}".format(val_cost_[ep]), " val_mrr={:.5f}".format(val_mrr_[ep]), " val_mrr_ema={:.5f}".format(shadow_mrr_[ep]), " duration={:.5f}".format(duration_[ep]))
 
    print("Optimization Finished!")
    
    # Save model
    save_path = saver.save(sess, model_path+'model.ckpt')
    print('Model saved at %s' % save_path)


    # Save loss node and count
    loss_node_path = './loss_node/unsup-' + FLAGS.train_prefix.split('/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name
    loss_node_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)
    if not os.path.exists(loss_node_path):
        os.makedirs(loss_node_path)

    sparse.save_npz(loss_node_path + 'loss_node.npz', sparse.csr_matrix(ln_acc))
    sparse.save_npz(loss_node_path + 'loss_node_count.npz', sparse.csr_matrix(lnc_acc))
    print('Per-node loss and count saved at %s' % loss_node_path)
    
    
    
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(sampler_name))

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
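            # Embeddings of train nodes are wrapped in stop_gradient and recombined
            # with the val/test-node embeddings via scatter_nd, so only the held-out
            # nodes are updated during the n2v test-time retraining below.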
            train_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if not G.node[n]['val'] and not G.node[n]['test']],
                    dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if G.node[n]['val'] or G.node[n]['test']], 
                    dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes, tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(tf.scatter_nd(train_ids, no_update_nodes, tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [n for n in G.nodes_iter() if G.node[n]["val"] or G.node[n]["test"]]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(G, 
                id_map,
                placeholders, batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree, 
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs = pairs,
                n2v_retrain=True,
                fixed_n2v=True)
            
            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([model.opt_op, model.loss, model.ranks, model.aff_all, 
                        model.mrr, model.outputs1], feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, 
                              "train_loss=", "{:.5f}".format(outs[1]),
                              "train_mrr=", "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(sampler_name), mod="-test")
            print("Total time: ", train_time+walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
Example #24
0
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                        sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                        sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter, "train_loss=",
                      "{:.5f}".format(train_cost), "train_f1_mic=",
                      "{:.5f}".format(train_f1_mic), "train_f1_mac=",
                      "{:.5f}".format(train_f1_mac), "val_loss=",
                      "{:.5f}".format(val_cost), "val_f1_mic=",
                      "{:.5f}".format(val_f1_mic), "val_f1_mac=",
                      "{:.5f}".format(val_f1_mac), "time=",
                      "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size)
    print("Full validation stats:", "loss=", "{:.5f}".format(val_cost),
          "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=",
          "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration))
    with open(log_dir() + "val_stats.txt", "w") as fp:
        fp.write(
            "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac, duration))

    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            val_cost, val_f1_mic, val_f1_mac))

    # NOTE: the placeholder/variable re-created here are not used below and
    # shadow the adj_info defined earlier in this function.
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    # print(sess.run([model.outputs1], feed_dict={adj_info_ph: minibatch.adj})[0].shape)
    # embedding from feed_dict
    # print(sess.run([model.outputs1], feed_dict=feed_dict)[0].shape)
    # minibatch.node_val_feed_dict(14755)[1].shape

    # embedding
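    # Build a single feed dict covering every node to extract embeddings in one
    # forward pass; the batch size of 14755 and output_folder_path below appear
    # to be specific to the dataset/machine this example was written for.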
    feed_dict_all = dict()
    feed_dict_all[placeholders['batch']] = minibatch.nodes
    feed_dict_all[placeholders['batch_size']] = 14755
    feed_dict_all[placeholders['dropout']] = 0
    feed_dict_all[placeholders['labels']] = minibatch.node_val_feed_dict(
        14755)[1]
    embedding_matrix = sess.run([model.outputs1], feed_dict=feed_dict_all)[0]
    output_folder_path = '/Volumes/DATA/workspace/aus/GraphSAGE/output'
    np.savetxt(output_folder_path + '/labels.txt', minibatch.nodes, fmt='%d')
    np.savetxt(output_folder_path + '/embedding.txt',
               embedding_matrix,
               fmt='%.8f')
    np.savetxt(output_folder_path + '/embedding_projector_format.txt',
               embedding_matrix,
               fmt='%.8f',
               delimiter='\t')
Example #25
0
def test(train_data, test_data=None):

    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_3 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2),
                SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2)
            ]
        elif FLAGS.samples_2 != 0:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
            ]
        else:
            layer_infos = [
                SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)
            ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos,
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="gcn",
                                    model_size=FLAGS.model_size,
                                    concat=False,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="seq",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SupervisedGraphsage(num_classes,
                                    placeholders,
                                    features,
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)

    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Saved model path
    saver = tf.train.Saver()
    model_path = './model/' + FLAGS.train_prefix.split(
        '/')[-1] + '-' + FLAGS.model_prefix + '/'

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Restore model
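    # NOTE: the checkpoint name passed to restore() must match what the training
    # run saved (other train() variants in this file save 'model.ckpt').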
    saver.restore(sess, model_path + 'model')

    total_steps = 0
    avg_time = 0.0

    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    sess.run(val_adj_info.op)
    print("Writing test set stats to file (don't peak!)")
    val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
        sess, model, minibatch, FLAGS.batch_size, test=True)
    with open(log_dir() + "test_stats.txt", "w") as fp:
        fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
            val_cost, val_f1_mic, val_f1_mac))
Example #26
0
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Saver defaults to all variables; added so the model can be checkpointed below (wy)
    saver = tf.train.Saver()

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model
    if FLAGS.isTrain:
        print("== Training Model ==")
        train_shadow_mrr = None
        shadow_mrr = None

        total_steps = 0
        avg_time = 0.0
        epoch_val_costs = []

        train_adj_info = tf.assign(adj_info, minibatch.adj)
        val_adj_info = tf.assign(adj_info, minibatch.test_adj)
        for epoch in range(FLAGS.epochs):
            minibatch.shuffle()

            iter = 0
            print('Epoch: %04d' % (epoch + 1))
            epoch_val_costs.append(0)
            while not minibatch.end():
                # Construct feed dictionary
                feed_dict = minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})

                t = time.time()
                # Training step
                outs = sess.run([
                    merged, model.opt_op, model.loss, model.ranks,
                    model.aff_all, model.mrr, model.outputs1
                ],
                                feed_dict=feed_dict)
                train_cost = outs[2]
                train_mrr = outs[5]
                if train_shadow_mrr is None:
                    train_shadow_mrr = train_mrr  # initialize the EMA with the first observation
                else:
                    train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr -
                                                      train_mrr)

                if iter % FLAGS.validate_iter == 0:
                    # Validation
                    sess.run(val_adj_info.op)
                    val_cost, ranks, val_mrr, duration = evaluate(
                        sess, model, minibatch, size=FLAGS.validate_batch_size)
                    sess.run(train_adj_info.op)
                    epoch_val_costs[-1] += val_cost
                if shadow_mrr is None:
                    shadow_mrr = val_mrr
                else:
                    shadow_mrr -= (1 - 0.99) * (shadow_mrr - val_mrr)

                if total_steps % FLAGS.print_every == 0:
                    summary_writer.add_summary(outs[0], total_steps)

                # Print results
                avg_time = (avg_time * total_steps + time.time() -
                            t) / (total_steps + 1)

                if total_steps % FLAGS.print_every == 0:
                    print(
                        "Iter:",
                        '%04d' % iter,
                        "train_loss=",
                        "{:.5f}".format(train_cost),
                        "train_mrr=",
                        "{:.5f}".format(train_mrr),
                        "train_mrr_ema=",
                        "{:.5f}".format(
                            train_shadow_mrr),  # exponential moving average
                        "val_loss=",
                        "{:.5f}".format(val_cost),
                        "val_mrr=",
                        "{:.5f}".format(val_mrr),
                        "val_mrr_ema=",
                        "{:.5f}".format(
                            shadow_mrr),  # exponential moving average
                        "time=",
                        "{:.5f}".format(avg_time))

                iter += 1
                total_steps += 1

                if total_steps > FLAGS.max_total_steps:
                    break

            if total_steps > FLAGS.max_total_steps:
                break
        print("Optimization Finished!")
        # save model
        saver.save(sess, log_dir() + 'model.ckpt')
        print("save model succeed!")
    else:
        print("== Testing Model ==")
        ckpt = tf.train.get_checkpoint_state(log_dir())
        print(log_dir())
        print(ckpt)
        input()  # pause for manual inspection (raw_input in Python 2)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print("No checkpoint found in %s" % log_dir())

    input()  # pause before saving embeddings (raw_input in Python 2)
    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    prob_matrix = train_data[5]

    if not features is None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1], ))])

    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders()
    minibatch = EdgeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      prob_matrix,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      num_neg_samples=FLAGS.neg_sample_size,
                                      context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)

    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)

    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)
        ]

        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)

    elif FLAGS.model == 'n2v':
        model = Node2VecModel(
            placeholders,
            features.shape[0],
            minibatch.deg,
            #2x because graphsage uses concat
            nodevec_dim=2 * FLAGS.dim_1,
            lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Train model

    train_shadow_mrr = None
    shadow_mrr = None

    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)

    min_train_loss = 0.0
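    # Track the lowest per-epoch training loss; node embeddings are re-exported
    # whenever an epoch improves on it (see the checks at the end of the epoch loop).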
    fp = open(log_dir() + 'epoch_train_loss.txt', 'w')
    loss_list = []
    epoch_list = []

    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        epoch_train_loss = 0.0

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
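            # Unlike the stock GraphSAGE iterator, this EdgeMinibatchIterator
            # variant takes layer_infos (and prob_matrix above), so this path
            # assumes a GraphSAGE branch that defined layer_infos (not 'n2v').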
            feed_dict = minibatch.next_minibatch_feed_dict(layer_infos)
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([
                merged, model.opt_op, model.loss, model.ranks, model.aff_all,
                model.mrr, model.outputs1
            ],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]

            epoch_train_loss += train_cost

            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr  # initialize the EMA with the first observation
            else:
                train_shadow_mrr -= (1 - 0.99) * (train_shadow_mrr - train_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() -
                        t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                print(
                    "Iter:",
                    '%04d' % iter,
                    "train_loss=",
                    "{:.5f}".format(train_cost),
                    "train_mrr=",
                    "{:.5f}".format(train_mrr),
                    "train_mrr_ema=",
                    "{:.5f}".format(
                        train_shadow_mrr),  # exponential moving average
                    "time=",
                    "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
            break

        epoch_train_loss = epoch_train_loss / iter

        print('epoch: ' + str(epoch) + '\t' + 'train_loss: ' +
              str(epoch_train_loss) + '\n')
        fp.write('epoch: ' + str(epoch) + '\t' + 'train_loss: ' +
                 str(epoch_train_loss) + '\n')
        loss_list.append(epoch_train_loss)
        epoch_list.append(epoch)

        if epoch == 0:
            min_train_loss = epoch_train_loss
            sess.run(val_adj_info.op)
            node_embeddings, node_list = save_val_embeddings(
                sess, model, minibatch, FLAGS.validate_batch_size, log_dir(),
                layer_infos)
        elif epoch_train_loss < min_train_loss:
            min_train_loss = epoch_train_loss
            sess.run(val_adj_info.op)
            node_embeddings, node_list = save_val_embeddings(
                sess, model, minibatch, FLAGS.validate_batch_size, log_dir(),
                layer_infos)

    print("Optimization Finished!")

    fp.close()

    np.save(log_dir() + "val.npy", node_embeddings)
    with open(log_dir() + "val.txt", "w") as fp:
        fp.write("\n".join(map(str, node_list)))

    plt.plot(epoch_list, loss_list)
    plt.savefig(log_dir() + 'loss_trend.png')
    plt.clf()

    ## t-SNE plot
    ## The TSNE projection is commented out below, so the first two raw
    ## embedding dimensions are plotted directly instead.
    #tsne_feat = TSNE(n_components=2).fit_transform(node_embeddings)
    tsne_feat = node_embeddings

    ## plotting t-SNE embedding
    tsne_df = pd.DataFrame(columns=['tsne0', 'tsne1'])
    tsne_df['tsne0'] = tsne_feat[:, 0]
    tsne_df['tsne1'] = tsne_feat[:, 1]
    plt.figure(figsize=(16, 10))
    sns.scatterplot(x="tsne0",
                    y="tsne1",
                    hue=node_list,
                    palette=sns.color_palette("hls", len(node_list)),
                    data=tsne_df,
                    legend=False,
                    alpha=0.3)
    for i, txt in enumerate(node_list):
        plt.annotate(txt, (tsne_df['tsne0'][i], tsne_df['tsne1'][i]))
    plt.savefig(log_dir() + 't-SNE_plot_node_embeddings.png')
    plt.show()

    if FLAGS.save_embeddings:
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size,
                            log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant(
                [[id_map[n]] for n in G.nodes_iter()
                 if not G.node[n]['val'] and not G.node[n]['test']],
                dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                    if G.node[n]['val'] or G.node[n]['test']],
                                   dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                  tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,
                                                     tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes,
                                         tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(
                tf.scatter_nd(train_ids, no_update_nodes,
                              tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [
                n for n in G.nodes_iter()
                if G.node[n]["val"] or G.node[n]["test"]
            ]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(
                G,
                id_map,
                placeholders,
                batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict(
                        layer_infos)
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([
                        model.opt_op, model.loss, model.ranks, model.aff_all,
                        model.mrr, model.outputs1
                    ],
                                    feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, "train_loss=",
                              "{:.5f}".format(outs[1]), "train_mrr=",
                              "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess,
                                model,
                                minibatch,
                                FLAGS.validate_batch_size,
                                log_dir(),
                                mod="-test")
            print("Total time: ", train_time + walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)