def predict(train_data, test_data=None): num_k = FLAGS.num_k G = train_data[0] #features = train_data[1] #features_store = np.copy(features) id_map = train_data[1] class_map = train_data[3] if isinstance(list(class_map.values())[0], list): num_classes = len(list(class_map.values())[0]) else: num_classes = len(set(class_map.values())) # if not features is None: # # pad with dummy zero vector # features = np.vstack([features, np.zeros((features.shape[1],))]) context_pairs = train_data[3] if FLAGS.random_context else None placeholders = construct_placeholders(num_classes) minibatch = NodeMinibatchIterator( G, id_map, placeholders, class_map, num_classes, batch_size=FLAGS.batch_size, max_degree=FLAGS.max_degree, context_pairs=context_pairs, budget=num_k, bud_mul_fac=FLAGS.bud_mul_fac, mode="test", prefix=FLAGS.train_prefix, sampling_freq=FLAGS.neighborhood_sampling) features = minibatch.features time_minibatch = minibatch.time_sampling_plus_norm # time_minibatch_end - time_minibatch_start print("creatred features and sampling done") print("minibatch time", time_minibatch) adj_info_tf_time_beg = time.time() adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape) adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") adj_info_tf_time_end = time.time() adj_info_tf_time = adj_info_tf_time_end - adj_info_tf_time_beg print("adj info time", adj_info_tf_time) if FLAGS.model == 'graphsage_mean': # Create model sampler = UniformNeighborSampler(adj_info) if FLAGS.samples_3 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2), SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2) ] elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] else: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos, model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'gcn': # Create model sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="gcn", model_size=FLAGS.model_size, concat=False, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_seq': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="seq", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_maxpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="maxpool", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif 
FLAGS.model == 'graphsage_meanpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="meanpool", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) else: raise Exception('Error: model name unrecognized.') config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement) config.gpu_options.allow_growth = True #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION config.allow_soft_placement = True # Initialize session sess = tf.Session(config=config) merged = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) # Init variables adj_info_tf_init_time_beg = time.time() sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj}) adj_info_tf_init_time_end = time.time() adj_info_tf_init_time = adj_info_tf_init_time_end - adj_info_tf_init_time_beg print("adj info time init", adj_info_tf_init_time) # train_adj_info = tf.assign(adj_info, minibatch.adj) # val_adj_info = tf.assign(adj_info, minibatch.test_adj) # Load trained model var_to_save = [] for var in tf.trainable_variables(): var_to_save.append(var) saver = tf.train.Saver(var_to_save) print("Trained Model Loading!") saver.restore(sess, "ICYoutubesupervisedTrainedModel_MC_marginal/model.ckpt") print("Trained Model Loaded!") sup_gs_pred_start_time = time.time() print("Predicting the classes of all the data set") predict_prob, pred_classes, embeddings = incremental_predict( sess, model, minibatch, FLAGS.batch_size) print("Predicted the classes of all the data set") sup_gs_pred_end_time = time.time() print("Saving the Predicted Output") sup_gs_prediction_time = sup_gs_pred_end_time - sup_gs_pred_start_time print("embed time ", sup_gs_prediction_time) # to output # print("predict_prob", predict_prob[143]) # print("pred_classes", pred_classes[143]) # print("embeddings", embeddings[143]) print("Saved the Predicted Output") active_prob_beg_time = time.time() active_one_prob = {} active_one_prob_dict = {} for index in range(0, len(pred_classes)): active_one_prob[index] = predict_prob[index][0] # print(index,pred_classes[index],predict_prob[index]) for index in range(0, len(pred_classes)): active_one_prob_dict[minibatch.dict_map_couter_to_real_node[ minibatch.top_degree_nodes[index]]] = predict_prob[index][0] # bottom_nodes, top_nodes = bipartite.sets(G) #bottom_nodes = list(bottom_nodes) total_top_ten_percent = len( active_one_prob ) # int(num_k*FLAGS.bud_mul_fac)#int(0.25*len(bottom_nodes)) # if total_top_ten_percent>500: # total_top_ten_percent = 500 sorted_dict = sorted(active_one_prob_dict.items(), key=operator.itemgetter(1), reverse=True) top_ten_percent = [] count_solution = 0 y = 0 while count_solution < total_top_ten_percent: #if sorted_dict[y][0] in bottom_nodes: top_ten_percent.append(sorted_dict[y][0]) count_solution = count_solution + 1 y = y + 1 active_prob_end_time = time.time() active_prob_time = active_prob_end_time - active_prob_beg_time print("actve prob time ", active_prob_time) result_top_percent = FLAGS.train_prefix + "_top_ten_percent.txt" + "_" + str( num_k) + "_nbs_" + str(FLAGS.neighborhood_sampling) file_handle2 = open(result_top_percent, "w") print('*******************', len(top_ten_percent)) dict_node_scores = {} for ind in top_ten_percent: 
file_handle2.write(str(ind)) file_handle2.write(" ") dict_node_scores[ind] = active_one_prob_dict[ind] file_handle2.close() dict_node_scores_file_name = FLAGS.train_prefix + "_node_scores_supgs" + "_" + str( num_k) + "_nbs_" + str(FLAGS.neighborhood_sampling) import pickle pickle_start_time = time.time() with open(dict_node_scores_file_name + '.pickle', 'wb') as handle: pickle.dump(dict_node_scores, handle, protocol=pickle.HIGHEST_PROTOCOL) pickle_end_time = time.time() pickle_time = pickle_end_time - pickle_start_time print("pickle scores time ", pickle_time) #print(dict_node_scores) result_top_percent = FLAGS.train_prefix + "_top_ten_percent_analyse.txt" + "_" + str( num_k) + "_nbs_" + str(FLAGS.neighborhood_sampling) file_handle2 = open(result_top_percent, "w") print('*******************', len(top_ten_percent)) print('******************* Writing top to file') graph_degree = G.out_degree(weight='weight') for ind in top_ten_percent: #print(ind) file_handle2.write( str(ind) + " " + str(graph_degree[ind]) + " " + str(active_one_prob_dict[ind])) file_handle2.write(" \n") file_handle2.close() print('******************* Written top to file') top_30 = [] count_solution = 0 y = 0 while count_solution < num_k: # if sorted_dict[y][0] in bottom_nodes: top_30.append(sorted_dict[y][0]) count_solution = count_solution + 1 y = y + 1 result_file_name = FLAGS.train_prefix + "_sup_GS_sol.txt" + "_" + str( num_k) + "_nbs_" + str(FLAGS.neighborhood_sampling) file_handle = open(result_file_name, "w") # file_handle.write(str(num_k)) # file_handle.write("\n") for ind in top_30: file_handle.write(str(ind)) file_handle.write("\n") file_handle.close() from sklearn.preprocessing import StandardScaler # scaler = StandardScaler() # scaler.fit(embeddings) # embeddings = scaler.transform(embeddings) # embeddings = np.array(embeddings) # embeddings = np.hstack([embeddings, features_store]) print('Final Embeddings shape = ', embeddings.shape) embedding_file_name = FLAGS.train_prefix + "_embeddings.npy" + "_" + str( num_k) + "_nbs_" + str(FLAGS.neighborhood_sampling) # np.save(embedding_file_name,embeddings) dict_embeddings_top_for_rl_without_rw = {} for index, node_id in enumerate(top_ten_percent): #print("map", node_id, minibatch.top_degree_nodes[index], index) embed_sup_gs = embeddings[index] dict_embeddings_top_for_rl_without_rw[node_id] = embed_sup_gs # print("index, nodeid ", index, node_id) import pickle with open(embedding_file_name + '.pickle', 'wb') as handle: pickle.dump(dict_embeddings_top_for_rl_without_rw, handle, protocol=pickle.HIGHEST_PROTOCOL) total_time_for_rl_prep = adj_info_tf_time + time_minibatch + sup_gs_prediction_time + active_prob_time + adj_info_tf_init_time #time_rl_prep_file_name=FLAGS.train_prefix + "_num_k_" + str(FLAGS.num_k) + "_time.txt"+"_"+str(num_k) time_rl_prep_file_name = FLAGS.train_prefix + "_num_k_" + str( FLAGS.num_k) + "_time.txt" + "_" + str(num_k) + "_nbs_" + str( FLAGS.neighborhood_sampling) print(time_rl_prep_file_name) time_file = open(time_rl_prep_file_name, 'w') time_file.write("RL_PREP_TIME_" + str(total_time_for_rl_prep) + '\n') # # reward_file_name = FLAGS.train_prefix + ".sup_GS_reward" # reward = evaluaterew.evaluate(G,top_30) # file_handle3 = open(reward_file_name,"w") # file_handle3.write(str(reward)) # file_handle3.close() print(" time rl prepare", total_time_for_rl_prep)
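# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): how the node-score
# dictionary pickled by predict() above could be read back and the k highest
# scoring nodes recovered. The path mirrors dict_node_scores_file_name; the
# helper name load_top_k_nodes is hypothetical.
def load_top_k_nodes(prefix, num_k, neighborhood_sampling, k):
    import pickle
    path = (prefix + "_node_scores_supgs" + "_" + str(num_k) + "_nbs_" +
            str(neighborhood_sampling) + ".pickle")
    with open(path, "rb") as handle:
        node_scores = pickle.load(handle)  # {real node id: class-0 probability}
    # Same ordering rule as predict(): sort by score, descending.
    ranked = sorted(node_scores.items(), key=lambda kv: kv[1], reverse=True)
    return [node for node, _ in ranked[:k]]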
def train(train_data, test_data=None): G = train_data[0] features = train_data[1] id_map = train_data[2] class_map = train_data[4] if isinstance(list(class_map.values())[0], list): num_classes = len(list(class_map.values())[0]) else: num_classes = len(set(class_map.values())) if not features is None: # pad with dummy zero vector features = np.vstack([features, np.zeros((features.shape[1], ))]) context_pairs = train_data[3] if FLAGS.random_context else None placeholders = construct_placeholders(num_classes) minibatch = NodeMinibatchIterator(G, id_map, placeholders, class_map, num_classes, batch_size=FLAGS.batch_size, max_degree=FLAGS.max_degree, context_pairs=context_pairs) adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape) adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") if FLAGS.model == 'graphsage_mean': # Create model sampler = UniformNeighborSampler(adj_info) if FLAGS.samples_3 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2), SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2) ] elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] else: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos, model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'gcn': # Create model sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="gcn", model_size=FLAGS.model_size, concat=False, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_seq': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="seq", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_maxpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="maxpool", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_meanpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="meanpool", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) else: raise Exception('Error: model name unrecognized.') config = 
tf.ConfigProto(log_device_placement=FLAGS.log_device_placement) config.gpu_options.allow_growth = True #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION config.allow_soft_placement = True # Initialize session sess = tf.Session(config=config) merged = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) # Init variables sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj}) # Train model total_steps = 0 avg_time = 0.0 epoch_val_costs = [] train_adj_info = tf.assign(adj_info, minibatch.adj) val_adj_info = tf.assign(adj_info, minibatch.test_adj) for epoch in range(FLAGS.epochs): minibatch.shuffle() iter = 0 print('Epoch: %04d' % (epoch + 1)) epoch_val_costs.append(0) while not minibatch.end(): # Construct feed dictionary feed_dict, labels = minibatch.next_minibatch_feed_dict() feed_dict.update({placeholders['dropout']: FLAGS.dropout}) t = time.time() # Training step outs = sess.run([merged, model.opt_op, model.loss, model.preds], feed_dict=feed_dict) train_cost = outs[2] if iter % FLAGS.validate_iter == 0: # Validation sess.run(val_adj_info.op) if FLAGS.validate_batch_size == -1: val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate( sess, model, minibatch, FLAGS.batch_size) else: val_cost, val_f1_mic, val_f1_mac, duration = evaluate( sess, model, minibatch, FLAGS.validate_batch_size) sess.run(train_adj_info.op) epoch_val_costs[-1] += val_cost if total_steps % FLAGS.print_every == 0: summary_writer.add_summary(outs[0], total_steps) # Print results avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1) if total_steps % FLAGS.print_every == 0: train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1]) print("Iter:", '%04d' % iter, "train_loss=", "{:.5f}".format(train_cost), "train_f1_mic=", "{:.5f}".format(train_f1_mic), "train_f1_mac=", "{:.5f}".format(train_f1_mac), "val_loss=", "{:.5f}".format(val_cost), "val_f1_mic=", "{:.5f}".format(val_f1_mic), "val_f1_mac=", "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(avg_time)) iter += 1 total_steps += 1 if total_steps > FLAGS.max_total_steps: break if total_steps > FLAGS.max_total_steps: break print("Optimization Finished!") sess.run(val_adj_info.op) val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate( sess, model, minibatch, FLAGS.batch_size) print("Full validation stats:", "loss=", "{:.5f}".format(val_cost), "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=", "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration)) with open(log_dir() + "val_stats.txt", "w") as fp: fp.write( "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".format( val_cost, val_f1_mic, val_f1_mac, duration)) print("Writing test set stats to file (don't peak!)") val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate( sess, model, minibatch, FLAGS.batch_size, test=True) with open(log_dir() + "test_stats.txt", "w") as fp: fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format( val_cost, val_f1_mic, val_f1_mac))
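# ---------------------------------------------------------------------------
# For reference only: a minimal sketch of the micro/macro F1 computation that
# a helper like calc_f1 above typically performs, assuming sigmoid outputs are
# thresholded at 0.5 and softmax outputs are argmax'ed. The repo's actual
# calc_f1 is defined elsewhere; calc_f1_sketch is a hypothetical stand-in.
def calc_f1_sketch(y_true, y_pred, sigmoid=True):
    import numpy as np
    from sklearn import metrics
    if sigmoid:
        y_true = (np.asarray(y_true) > 0.5).astype(int)
        y_pred = (np.asarray(y_pred) > 0.5).astype(int)
    else:
        y_true = np.argmax(y_true, axis=1)
        y_pred = np.argmax(y_pred, axis=1)
    return (metrics.f1_score(y_true, y_pred, average="micro"),
            metrics.f1_score(y_true, y_pred, average="macro"))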
SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2) ] elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] else: layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos, f_alpha=FLAGS.focal_alpha, model_size=FLAGS.model_size, sigmoid_loss=FLAGS.loss, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'gcn': # Create model sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes,
def train(train_data, test_data=None): G = train_data[0] # G 是一个Networkx里的对象,这几个都是经过load_data()处理过的 features = train_data[1] id_map = train_data[2] class_map = train_data[4] class_map2 = train_data[5] class_map3 = train_data[6] #class_map = class_map hierarchy = FLAGS.hierarchy ko_threshold = FLAGS.ko_threshold ko_threshold2 = FLAGS.ko_threshold2 if features is not None: # pad with dummy zero vector features = np.vstack([features, np.zeros((features.shape[1], ))]) features = tf.cast(features, tf.float32) for hi_num in range(hierarchy): if hi_num == 0: class_map_ko_0 = construct_class_numpy(class_map) class_map_ko = construct_class_numpy(class_map) a = class_map_ko.sum(axis=0) count = 0 list_del = [] for i in a: if i < ko_threshold: list_del.append(count) count += 1 else: count += 1 class_map_ko = np.delete(class_map_ko, list_del, axis=1) count = 0 for key in class_map: arr = class_map_ko[count, :] class_map[key] = arr.tolist() count += 1 num_classes = class_map_ko.shape[1] elif hi_num == 1: class_map = class_map2 class_map_ko_1 = construct_class_numpy(class_map) class_map_ko = construct_class_numpy(class_map) a = class_map_ko.sum(axis=0) count = 0 list_del = [] for i in a: if i < ko_threshold2: list_del.append(count) count += 1 else: count += 1 class_map_ko = np.delete(class_map_ko, list_del, axis=1) count = 0 for key in class_map: arr = class_map_ko[count, :] class_map[key] = arr.tolist() count += 1 num_classes = class_map_ko.shape[1] elif hi_num == 2: class_map = class_map3 class_map_ko_2 = construct_class_numpy(class_map) class_map_ko = construct_class_numpy(class_map) a = class_map_ko.sum(axis=0) count = 0 list_del = [] for i in a: if i > ko_threshold2: list_del.append(count) count += 1 else: count += 1 class_map_ko = np.delete(class_map_ko, list_del, axis=1) count = 0 for key in class_map: arr = class_map_ko[count, :] class_map[key] = arr.tolist() count += 1 num_classes = class_map_ko.shape[1] OTU_ko_num = class_map_ko.sum(axis=1) count = 0 for num in OTU_ko_num: if num < 100: count += 1 ko_cb = construct_class_para(class_map_ko, 0, FLAGS.beta1) ko_cb = tf.cast(ko_cb, tf.float32) f1_par = construct_class_para(class_map_ko, 1, FLAGS.beta2) context_pairs = train_data[3] if FLAGS.random_context else None placeholders = construct_placeholders(num_classes) minibatch = NodeMinibatchIterator(G, id_map, placeholders, class_map, num_classes, batch_size=FLAGS.batch_size, max_degree=FLAGS.max_degree, context_pairs=context_pairs) with open('test_nodes.txt', 'w') as f: json.dump(minibatch.test_nodes, f) ########### list_node = minibatch.nodes for otu in minibatch.train_nodes: if otu in list_node: list_node.remove(otu) for otu in minibatch.val_nodes: if otu in list_node: list_node.remove(otu) for otu in minibatch.test_nodes: if otu in list_node: list_node.remove(otu) ########### if hi_num == 0: adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape) # 把adj_info设成Variable应该是因为在训练和测试时会改变adj_info的值,所以 # 用Varible然后用tf.assign()赋值。 adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") if FLAGS.model == 'graphsage_mean': # Create model sampler = UniformNeighborSampler(adj_info) if FLAGS.samples_3 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2), SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2) ] elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] else: layer_infos = [ SAGEInfo("node", 
sampler, FLAGS.samples_1, FLAGS.dim_1) ] model = SupervisedGraphsage( num_classes, placeholders, features, adj_info, minibatch.deg, # 每一个的度 layer_infos, ko_cb, hi_num, model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True, concat=False) elif FLAGS.model == 'gcn': # Create model sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="gcn", model_size=FLAGS.model_size, concat=False, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_seq': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="seq", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True, concat=True) elif FLAGS.model == 'graphsage_maxpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="maxpool", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True, concat=True) elif FLAGS.model == 'graphsage_meanpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="meanpool", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True, concat=True) elif FLAGS.model == 'gat': sampler = UniformNeighborSampler(adj_info) # 建立两层网络 采样邻居、邻居个数、输出维度 layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage( num_classes, placeholders, features, adj_info, minibatch.deg, concat=True, layer_infos=layer_infos, aggregator_type="gat", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True, ) else: raise Exception('Error: model name unrecognized.') config = tf.ConfigProto( log_device_placement=FLAGS.log_device_placement) config.gpu_options.allow_growth = True config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION config.allow_soft_placement = True # Initialize session sess = tf.Session(config=config) # sess = tf_dbg.LocalCLIDebugWrapperSession(sess) #merged = tf.summary.merge_all() # 将所有东西保存到磁盘,可视化会用到 #summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) # 记录信息,可视化,可以用tensorboard查看 # Init variables sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj}) #sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph2: minibatch2.adj}) # Train model total_steps = 0 avg_time = 0.0 epoch_val_costs = [] epoch_val_costs2 = [] # 这里minibatch.adj和minibathc.test_adj的大小是一样的,只不过adj里面把不是train的值都变成一样 # 
val here stands for validation train_adj_info = tf.assign( adj_info, minibatch.adj ) # tf.assign() assigns a value to a tf.Variable; the return value is a Variable holding the assigned value val_adj_info = tf.assign( adj_info, minibatch.test_adj) # assign() is an Operation and only executes via sess.run() it = 0 train_loss = [] val_loss = [] train_f1_mics = [] val_f1_mics = [] loss_plt = [] loss_plt2 = [] trainf1mi = [] trainf1ma = [] valf1mi = [] valf1ma = [] iter_num = 0 for epoch in range(FLAGS.epochs * 2): if epoch < FLAGS.epochs: minibatch.shuffle() iter = 0 print('Epoch: %04d' % (epoch + 1)) epoch_val_costs.append(0) while not minibatch.end(): # Construct feed dictionary # the nodes of each minibatch are selected by changing feed_dict feed_dict, labels = minibatch.next_minibatch_feed_dict( ) # feed_dict contains the placeholders filled in by the minibatch feed_dict.update({placeholders['dropout']: FLAGS.dropout}) t = time.time() # Training step outs = sess.run([model.opt_op, model.loss, model.preds], feed_dict=feed_dict) train_cost = outs[1] iter_num = iter_num + 1 loss_plt.append(float(train_cost)) if iter % FLAGS.print_every == 0: # Validation set sess.run(val_adj_info.op ) # the fetch argument of sess.run() is an Operation, i.e. this operation gets executed. if FLAGS.validate_batch_size == -1: val_cost, val_f1_mic, val_f1_mac, duration, otu_lazy, _, val_preds, __, val_accuracy, val_mi_roc_auc = incremental_evaluate( sess, model, minibatch, f1_par, FLAGS.batch_size) else: val_cost, val_f1_mic, val_f1_mac, duration, val_accuracy, val_mi_roc_auc = evaluate( sess, model, minibatch, f1_par, FLAGS.validate_batch_size) sess.run(train_adj_info.op ) # every tensor has an op attribute, the operation that produced it. epoch_val_costs[-1] += val_cost #if iter % FLAGS.print_every == 0: #summary_writer.add_summary(outs[0], total_steps) # Print results avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1) loss_plt2.append(float(val_cost)) valf1mi.append(float(val_f1_mic)) valf1ma.append(float(val_f1_mac)) if iter % FLAGS.print_every == 0: train_f1_mic, train_f1_mac, train_f1_none, train_accuracy, train_mi_roc_auc = calc_f1( labels, outs[-1], f1_par) trainf1mi.append(float(train_f1_mic)) trainf1ma.append(float(train_f1_mac)) print( "Iter:", '%04d' % iter, # loss and metrics on the training set "train_loss=", "{:.5f}".format(train_cost), "train_f1_mic=", "{:.5f}".format(train_f1_mic), "train_f1_mac=", "{:.5f}".format(train_f1_mac), "train_accuracy=", "{:.5f}".format(train_accuracy), "train_ra_mi=", "{:.5f}".format(train_mi_roc_auc), # loss and metrics on the validation set "val_loss=", "{:.5f}".format(val_cost), "val_f1_mic=", "{:.5f}".format(val_f1_mic), "val_f1_mac=", "{:.5f}".format(val_f1_mac), "val_accuracy=", "{:.5f}".format(val_accuracy), "val_ra_mi=", "{:.5f}".format(val_mi_roc_auc), "time=", "{:.5f}".format(avg_time)) train_loss.append(train_cost) val_loss.append(val_cost) train_f1_mics.append(train_f1_mic) val_f1_mics.append(val_f1_mic) iter += 1 total_steps += 1 if total_steps > FLAGS.max_total_steps: break if total_steps > FLAGS.max_total_steps: break ################################################################################################################### # begin second degree training ################################################################################################################### """"" else: minibatch2.shuffle() iter = 0 print('Epoch2: %04d' % (epoch + 1)) epoch_val_costs2.append(0) while not minibatch2.end(): # Construct feed dictionary # the nodes of each minibatch are selected by changing feed_dict feed_dict, labels = minibatch2.next_minibatch_feed_dict() # feed_dict contains the placeholders filled in by the minibatch feed_dict.update({placeholders2['dropout']: FLAGS.dropout}) t = time.time() # Training step #global model2 outs = sess.run([merged, 
model2.opt_op, model2.loss, model2.preds], feed_dict=feed_dict) train_cost = outs[2] iter_num = iter_num + 1 loss_plt.append(float(train_cost)) if iter % FLAGS.print_every == 0: # Validation 验证集 sess.run(val_adj_info2.op) # sess.run() fetch参数是一个Opration,代表执行这个操作。 if FLAGS.validate_batch_size == -1: val_cost, val_f1_mic, val_f1_mac, duration, otu_lazy = incremental_evaluate(sess, model2, minibatch2, FLAGS.batch_size) else: val_cost, val_f1_mic, val_f1_mac, duration = evaluate(sess, model2, minibatch2, FLAGS.validate_batch_size) sess.run(train_adj_info2.op) # 每一个tensor都有op属性,代表产生这个张量的opration。 epoch_val_costs2[-1] += val_cost if iter % FLAGS.print_every == 0: summary_writer.add_summary(outs[0], total_steps) # Print results avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1) loss_plt2.append(float(val_cost)) valf1mi.append(float(val_f1_mic)) valf1ma.append(float(val_f1_mac)) if iter % FLAGS.print_every == 0: train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1]) trainf1mi.append(float(train_f1_mic)) trainf1ma.append(float(train_f1_mac)) print("Iter:", '%04d' % iter, # 训练集上的损失函数等信息 "train_loss=", "{:.5f}".format(train_cost), "train_f1_mic=", "{:.5f}".format(train_f1_mic), "train_f1_mac=", "{:.5f}".format(train_f1_mac), # 在测试集上的损失函数值等信息 "val_loss=", "{:.5f}".format(val_cost), "val_f1_mic=", "{:.5f}".format(val_f1_mic), "val_f1_mac=", "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(avg_time)) train_loss.append(train_cost) val_loss.append(val_cost) train_f1_mics.append(train_f1_mic) val_f1_mics.append(val_f1_mic) iter += 1 total_steps += 1 if total_steps > FLAGS.max_total_steps: break if total_steps > FLAGS.max_total_steps: break """ print("Optimization Finished!") sess.run(val_adj_info.op) if hi_num == 1: last_preds = test_preds last_labels = test_labels val_cost, val_f1_mic, val_f1_mac, duration, otu_f1, ko_none, test_preds, test_labels, test_accuracy, test_mi_roc_auc = incremental_evaluate( sess, model, minibatch, f1_par, FLAGS.batch_size, test=True) print( "Full validation stats:", "loss=", "{:.5f}".format(val_cost), "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=", "{:.5f}".format(val_f1_mac), "accuracy=", "{:.5f}".format(test_accuracy), "roc_auc_mi=", "{:.5f}".format(test_mi_roc_auc), "time=", "{:.5f}".format(duration), ) if hi_num == 1: # update test preds """ ab_ko = json.load(open(FLAGS.train_prefix + "-below1500_ko_idx.json")) #ab_ko = construct_class_numpy(ab_ko) f1_par = construct_class_para(class_map_ko_0, 1, FLAGS.beta2) i = 0 for col in ab_ko: last_preds[..., col] = test_preds[..., i] i += 1 f1_scores = calc_f1(last_preds, last_labels, f1_par) """ f1_par = construct_class_para(class_map_ko_0, 1, FLAGS.beta2) final_preds = np.hstack((last_preds, test_preds)) final_labels = np.hstack((last_labels, test_labels)) elif hi_num == 2: f1_par = construct_class_para(class_map_ko_0, 1, FLAGS.beta2) final_preds = np.hstack((final_preds, test_preds)) final_labels = np.hstack((final_labels, test_labels)) f1_scores = calc_f1(final_preds, final_labels, f1_par) print('\n', 'Hierarchy combination f1 score:') print("f1_micro=", "{:.5f}".format(f1_scores[0]), "f1_macro=", "{:.5f}".format(f1_scores[1]), "accuracy=", "{:.5f}".format(f1_scores[3]), "roc_auc_mi=", "{:.5f}".format(f1_scores[4])) pred = y_ture_pre(sess, model, minibatch, FLAGS.batch_size) for i in range(pred.shape[0]): sum = 0 for l in range(pred.shape[1]): sum = sum + pred[i, l] for m in range(pred.shape[1]): pred[i, m] = pred[i, m] / sum id = json.load(open(FLAGS.train_prefix + "-id_map.json")) # x_train = 
np.empty([pred.shape[0], array.s) num = 0 session = tf.Session() array = session.run(features) x_test = np.empty([pred.shape[0], array.shape[1]]) x_train = np.empty([len(minibatch.train_nodes), array.shape[1]]) for node in minibatch.val_nodes: x_test[num] = array[id[node]] num = num + 1 num1 = 0 for node in minibatch.train_nodes: x_train[num1] = array[id[node]] num1 = num1 + 1 with open(log_dir() + "val_stats.txt", "w") as fp: fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}". format(val_cost, val_f1_mic, val_f1_mac, duration)) print("Writing test set stats to file (don't peak!)") val_cost, val_f1_mic, val_f1_mac, duration, otu_lazy, ko_none, _, __, test_accuracy, test_mi_roc_auc = incremental_evaluate( sess, model, minibatch, f1_par, FLAGS.batch_size, test=True) with open(log_dir() + "test_stats.txt", "w") as fp: fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format( val_cost, val_f1_mic, val_f1_mac)) incremental_evaluate_for_each(sess, model, minibatch, FLAGS.batch_size, test=True) ################################################################################################################## # plot loss plt.figure() plt.plot(loss_plt, label='train_loss') plt.plot(loss_plt2, label='val_loss') plt.legend(loc=0) plt.xlabel('Iteration') plt.ylabel('loss') plt.title('Loss plot') plt.grid(True) plt.axis('tight') #plt.savefig("./graph/HMC_SAGE_CB_loss.png") # plt.show() # plot f1 score plt.figure() plt.subplot(211) plt.plot(trainf1mi, label='train_f1_micro') plt.plot(valf1mi, label='val_f1_micro') plt.legend(loc=0) plt.xlabel('Iterations') plt.ylabel('f1_micro') plt.title('train_val_f1_score') plt.grid(True) plt.axis('tight') plt.subplot(212) plt.plot(trainf1ma, label='train_f1_macro') plt.plot(valf1ma, label='val_f1_macro') plt.legend(loc=0) plt.xlabel('Iteration') plt.ylabel('f1_macro') plt.grid(True) plt.axis('tight') # plt.savefig("./graph/HMC_SAGE_CB_f1.png") # plt.show() plt.figure() plt.plot(np.arange(len(train_loss)) + 1, train_loss, label='train') plt.plot(np.arange(len(val_loss)) + 1, val_loss, label='val') plt.legend() plt.savefig('loss.png') plt.figure() plt.plot(np.arange(len(train_f1_mics)) + 1, train_f1_mics, label='train') plt.plot(np.arange(len(val_f1_mics)) + 1, val_f1_mics, label='val') plt.legend() #plt.savefig('f1.png') # OTU f1 plt.figure() plt.plot(otu_f1, label='otu_f1') plt.legend(loc=0) plt.xlabel('OTU') plt.ylabel('f1_score') plt.title('OTU f1 plot') plt.grid(True) plt.axis('tight') #plt.savefig("./graph/below_1500_CECB15_otu_f1.png") # plt.show() ko_none = f1_scores[2] # Ko f1 score plt.figure() plt.plot(ko_none, label='Ko f1 score') plt.legend(loc=0) plt.xlabel('Ko') plt.ylabel('f1_score') plt.grid(True) plt.axis('tight') #plt.savefig("./graph/below1500_CECB15_ko_f1.png") bad_ko = [] b02 = 0 b05 = 0 b07 = 0 for i in range(len(ko_none)): if ko_none[i] < 0.2: bad_ko.append(i) b02 += 1 elif ko_none[i] < 0.5: b05 += 1 elif ko_none[i] < 0.7: b07 += 1 print("ko f1 below 0.2:", b02) print("ko f1 below 0.5:", b05) print("ko f1 below 0.7:", b07) print("ko f1 over 0.7:", len(ko_none) - b02 - b05 - b07) bad_ko = np.array(bad_ko) with open('./new_data_badko/graph7 ko below zero point two .txt', 'w') as f: np.savetxt(f, bad_ko, fmt='%d', delimiter=",")
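# ---------------------------------------------------------------------------
# Side note (illustrative, not called anywhere): the hierarchy loop above
# drops label columns whose support falls below ko_threshold (for hi_num 0
# and 1) before training. A compact numpy equivalent of that filtering,
# assuming class_map_ko is the (num_nodes x num_labels) 0/1 matrix returned
# by construct_class_numpy; the helper name is hypothetical.
def filter_rare_label_columns(class_map_ko, threshold):
    import numpy as np
    mat = np.asarray(class_map_ko)
    support = mat.sum(axis=0)        # occurrences of each label column
    keep = support >= threshold      # mirrors the `i < threshold` deletion loop
    return mat[:, keep], np.where(keep)[0]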
def train(train_data, test_data=None, sampler_name='Uniform'): G = train_data[0] features = train_data[1] id_map = train_data[2] class_map = train_data[4] if isinstance(list(class_map.values())[0], list): num_classes = len(list(class_map.values())[0]) else: num_classes = len(set(class_map.values())) if not features is None: # pad with dummy zero vector features = np.vstack([features, np.zeros((features.shape[1], ))]) context_pairs = train_data[3] if FLAGS.random_context else None placeholders = construct_placeholders(num_classes) minibatch = NodeMinibatchIterator(G, id_map, placeholders, class_map, num_classes, batch_size=FLAGS.batch_size, max_degree=FLAGS.max_degree, context_pairs=context_pairs) adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape) adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") adj_shape = adj_info.get_shape().as_list() # loss_node = tf.SparseTensor(indices=np.empty((0,2), dtype=np.int64), values=[], dense_shape=[adj_shape[0], adj_shape[0]]) # loss_node_count = tf.SparseTensor(indices=np.empty((0,2), dtype=np.int64), values=[], dense_shape=[adj_shape[0], adj_shape[0]]) # # newly added for storing cost in each adj cell # loss_node = tf.Variable(tf.zeros([minibatch.adj.shape[0], minibatch.adj.shape[0]]), trainable=False, name="loss_node", dtype=tf.float32) # loss_node_count = tf.Variable(tf.zeros([minibatch.adj.shape[0], minibatch.adj.shape[0]]), trainable=False, name="loss_node_count", dtype=tf.float32) if FLAGS.model == 'mean_concat': # Create model if sampler_name == 'Uniform': sampler = UniformNeighborSampler(adj_info) elif sampler_name == 'ML': sampler = MLNeighborSampler(adj_info, features) elif sampler_name == 'FastML': sampler = FastMLNeighborSampler(adj_info, features) if FLAGS.samples_3 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2), SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_3) ] elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] else: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1) ] ''' ### 3 layer test layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2), SAGEInfo("node", sampler, 25, FLAGS.dim_2), SAGEInfo("node", sampler, 10, FLAGS.dim_2)] ''' # modified model = SupervisedGraphsage( num_classes, placeholders, features, adj_info, #loss_node, #loss_node_count, minibatch.deg, layer_infos, concat=True, model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) # # model = SupervisedGraphsage(num_classes, placeholders, # features, # adj_info, # minibatch.deg, # layer_infos, # model_size=FLAGS.model_size, # sigmoid_loss = FLAGS.sigmoid, # identity_dim = FLAGS.identity_dim, # logging=True) elif FLAGS.model == 'mean_add': # Create model if sampler_name == 'Uniform': sampler = UniformNeighborSampler(adj_info) elif sampler_name == 'ML': sampler = MLNeighborSampler(adj_info, features) elif sampler_name == 'FastML': sampler = FastMLNeighborSampler(adj_info, features) if FLAGS.samples_3 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2), SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_3) ] elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] else: layer_infos = [ 
SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1) ] ''' ### 3 layer test layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2), SAGEInfo("node", sampler, 25, FLAGS.dim_2), SAGEInfo("node", sampler, 10, FLAGS.dim_2)] ''' # modified model = SupervisedGraphsage( num_classes, placeholders, features, adj_info, #loss_node, #loss_node_count, minibatch.deg, layer_infos, concat=False, model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) # # model = SupervisedGraphsage(num_classes, placeholders, # features, # adj_info, # minibatch.deg, # layer_infos, # model_size=FLAGS.model_size, # sigmoid_loss = FLAGS.sigmoid, # identity_dim = FLAGS.identity_dim, # logging=True) elif FLAGS.model == 'LRmean_add': # Create model if sampler_name == 'Uniform': sampler = UniformNeighborSampler(adj_info) elif sampler_name == 'ML': sampler = MLNeighborSampler(adj_info, features) elif sampler_name == 'FastML': sampler = FastMLNeighborSampler(adj_info, features) if FLAGS.samples_3 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2), SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_3) ] elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] else: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1) ] ''' ### 3 layer test layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2), SAGEInfo("node", sampler, 25, FLAGS.dim_2), SAGEInfo("node", sampler, 10, FLAGS.dim_2)] ''' # modified model = SupervisedGraphsage( num_classes, placeholders, features, adj_info, #loss_node, #loss_node_count, minibatch.deg, layer_infos, aggregator_type="LRmean", concat=False, model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) # # model = SupervisedGraphsage(num_classes, placeholders, # features, # adj_info, # minibatch.deg, # layer_infos, # model_size=FLAGS.model_size, # sigmoid_loss = FLAGS.sigmoid, # identity_dim = FLAGS.identity_dim, # logging=True) elif FLAGS.model == 'logicmean': # Create model if sampler_name == 'Uniform': sampler = UniformNeighborSampler(adj_info) elif sampler_name == 'ML': sampler = MLNeighborSampler(adj_info, features) elif sampler_name == 'FastML': sampler = FastMLNeighborSampler(adj_info, features) if FLAGS.samples_3 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2), SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2) ] elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] else: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1) ] ''' ### 3 layer test layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2), SAGEInfo("node", sampler, 25, FLAGS.dim_2), SAGEInfo("node", sampler, 10, FLAGS.dim_2)] ''' # modified model = SupervisedGraphsage( num_classes, placeholders, features, adj_info, #loss_node, #loss_node_count, minibatch.deg, layer_infos, aggregator_type='logicmean', concat=True, model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) # elif FLAGS.model == 'attmean': # Create model if sampler_name == 'Uniform': sampler = UniformNeighborSampler(adj_info) elif sampler_name == 'ML': sampler = MLNeighborSampler(adj_info, features) elif sampler_name 
== 'FastML': sampler = FastMLNeighborSampler(adj_info, features) if FLAGS.samples_3 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2), SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_2) ] elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] else: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1) ] ''' ### 3 layer test layer_infos = [SAGEInfo("node", sampler, 50, FLAGS.dim_2), SAGEInfo("node", sampler, 25, FLAGS.dim_2), SAGEInfo("node", sampler, 10, FLAGS.dim_2)] ''' # modified model = SupervisedGraphsage( num_classes, placeholders, features, adj_info, #loss_node, #loss_node_count, minibatch.deg, layer_infos, aggregator_type='attmean', model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) # # model = SupervisedGraphsage(num_classes, placeholders, # features, # adj_info, # minibatch.deg, # layer_infos, # model_size=FLAGS.model_size, # sigmoid_loss = FLAGS.sigmoid, # identity_dim = FLAGS.identity_dim, # logging=True) elif FLAGS.model == 'gcn': if sampler_name == 'Uniform': sampler = UniformNeighborSampler(adj_info) elif sampler_name == 'ML': sampler = MLNeighborSampler(adj_info, features) elif sampler_name == 'FastML': sampler = FastMLNeighborSampler(adj_info, features) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="gcn", model_size=FLAGS.model_size, concat=False, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_seq': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="seq", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_maxpool': if sampler_name == 'Uniform': sampler = UniformNeighborSampler(adj_info) elif sampler_name == 'ML': sampler = MLNeighborSampler(adj_info, features) elif sampler_name == 'FastML': sampler = FastMLNeighborSampler(adj_info, features) #sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="maxpool", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_meanpool': if sampler_name == 'Uniform': sampler = UniformNeighborSampler(adj_info) elif sampler_name == 'ML': sampler = MLNeighborSampler(adj_info, features) elif sampler_name == 'FastML': sampler = FastMLNeighborSampler(adj_info, features) #sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, 
layer_infos=layer_infos, aggregator_type="meanpool", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, identity_dim=FLAGS.identity_dim, logging=True) else: raise Exception('Error: model name unrecognized.') config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement) config.gpu_options.allow_growth = True #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION config.allow_soft_placement = True # Initialize session sess = tf.Session(config=config) merged = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(log_dir(sampler_name), sess.graph) # Save model saver = tf.train.Saver() model_path = './model/' + FLAGS.train_prefix.split( '/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format( model=FLAGS.model, model_size=FLAGS.model_size, lr=FLAGS.learning_rate) if not os.path.exists(model_path): os.makedirs(model_path) # Init variables sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj}) # Restore params of ML sampler model if sampler_name == 'ML' or sampler_name == 'FastML': sampler_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="MLsampler") #pdb.set_trace() saver_sampler = tf.train.Saver(var_list=sampler_vars) sampler_model_path = './model/MLsampler-' + FLAGS.train_prefix.split( '/')[-1] + '-' + FLAGS.model_prefix sampler_model_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format( model=FLAGS.model, model_size=FLAGS.model_size, lr=FLAGS.learning_rate) saver_sampler.restore(sess, sampler_model_path + 'model.ckpt') # Train model total_steps = 0 avg_time = 0.0 epoch_val_costs = [] train_adj_info = tf.assign(adj_info, minibatch.adj) val_adj_info = tf.assign(adj_info, minibatch.test_adj) val_cost_ = [] val_f1_mic_ = [] val_f1_mac_ = [] duration_ = [] ln_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.float32) lnc_acc = sparse.csr_matrix((adj_shape[0], adj_shape[0]), dtype=np.int32) ln_acc = ln_acc.tolil() lnc_acc = lnc_acc.tolil() # # ln_acc = np.zeros([adj_shape[0], adj_shape[0]]) # lnc_acc = np.zeros([adj_shape[0], adj_shape[0]]) for epoch in range(FLAGS.epochs): minibatch.shuffle() iter = 0 print('Epoch: %04d' % (epoch + 1)) epoch_val_costs.append(0) #for j in range(2): while not minibatch.end(): # Construct feed dictionary feed_dict, labels = minibatch.next_minibatch_feed_dict() if list(feed_dict.values())[0] != FLAGS.batch_size: break feed_dict.update({placeholders['dropout']: FLAGS.dropout}) t = time.time() # Training step #outs = sess.run([merged, model.opt_op, model.loss, model.preds], feed_dict=feed_dict) outs = sess.run([ merged, model.opt_op, model.loss, model.preds, model.loss_node, model.loss_node_count, model.out_mean ], feed_dict=feed_dict) train_cost = outs[2] if iter % FLAGS.validate_iter == 0: # Validation sess.run(val_adj_info.op) if FLAGS.validate_batch_size == -1: val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate( sess, model, minibatch, FLAGS.batch_size) else: val_cost, val_f1_mic, val_f1_mac, duration = evaluate( sess, model, minibatch, FLAGS.validate_batch_size) # accumulate val results val_cost_.append(val_cost) val_f1_mic_.append(val_f1_mic) val_f1_mac_.append(val_f1_mac) duration_.append(duration) # sess.run(train_adj_info.op) epoch_val_costs[-1] += val_cost if total_steps % FLAGS.print_every == 0: summary_writer.add_summary(outs[0], total_steps) # Print results avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1) # loss_node #import pdb #pdb.set_trace() # if epoch > 
0.7*FLAGS.epochs: # ln = outs[-2].values # ln_idx = outs[-2].indices # ln_acc[ln_idx[:,0], ln_idx[:,1]] += ln # # # lnc = outs[-1].values # lnc_idx = outs[-1].indices # lnc_acc[lnc_idx[:,0], lnc_idx[:,1]] += lnc ln = outs[4].values ln_idx = outs[4].indices ln_acc[ln_idx[:, 0], ln_idx[:, 1]] += ln lnc = outs[5].values lnc_idx = outs[5].indices lnc_acc[lnc_idx[:, 0], lnc_idx[:, 1]] += lnc #pdb.set_trace() #idx = np.where(lnc_acc != 0) #loss_node_mean = (ln_acc[idx[0], idx[1]]).mean() #loss_node_count_mean = (lnc_acc[idx[0], idx[1]]).mean() if total_steps % FLAGS.print_every == 0: train_f1_mic, train_f1_mac = calc_f1(labels, outs[3]) print( "Iter:", '%04d' % iter, "train_loss=", "{:.5f}".format(train_cost), "train_f1_mic=", "{:.5f}".format(train_f1_mic), "train_f1_mac=", "{:.5f}".format(train_f1_mac), "val_loss=", "{:.5f}".format(val_cost), "val_f1_mic=", "{:.5f}".format(val_f1_mic), "val_f1_mac=", "{:.5f}".format(val_f1_mac), #"loss_node=", "{:.5f}".format(loss_node_mean), #"loss_node_count=", "{:.5f}".format(loss_node_count_mean), "time=", "{:.5f}".format(avg_time)) iter += 1 total_steps += 1 if total_steps > FLAGS.max_total_steps: break if total_steps > FLAGS.max_total_steps: break # Save model save_path = saver.save(sess, model_path + 'model.ckpt') print('model is saved at %s' % save_path) # Save loss node and count loss_node_path = './loss_node/' + FLAGS.train_prefix.split( '/')[-1] + '-' + FLAGS.model_prefix + '-' + sampler_name loss_node_path += "/{model:s}_{model_size:s}_{lr:0.4f}/".format( model=FLAGS.model, model_size=FLAGS.model_size, lr=FLAGS.learning_rate) if not os.path.exists(loss_node_path): os.makedirs(loss_node_path) loss_node = sparse.save_npz(loss_node_path + 'loss_node.npz', sparse.csr_matrix(ln_acc)) loss_node_count = sparse.save_npz(loss_node_path + 'loss_node_count.npz', sparse.csr_matrix(lnc_acc)) print('loss and count per node is saved at %s' % loss_node_path) # # save images of loss node and count # plt.imsave(loss_node_path + 'loss_node_mean.png', np.uint8(np.round(np.divide(ln_acc.todense()[:1024,:1024], lnc_acc.todense()[:1024,:1024]+1e-10))), cmap='jet', vmin=0, vmax=255) # plt.imsave(loss_node_path + 'loss_node_count.png', np.uint8(lnc_acc.todense()[:1024,:1024]), cmap='jet', vmin=0, vmax=255) # print("Validation per epoch in training") for ep in range(FLAGS.epochs): print("Epoch: %04d" % ep, " val_cost={:.5f}".format(val_cost_[ep]), " val_f1_mic={:.5f}".format(val_f1_mic_[ep]), " val_f1_mac={:.5f}".format(val_f1_mac_[ep]), " duration={:.5f}".format(duration_[ep])) print("Optimization Finished!") sess.run(val_adj_info.op) # full validation val_cost_ = [] val_f1_mic_ = [] val_f1_mac_ = [] duration_ = [] for iter in range(10): val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate( sess, model, minibatch, FLAGS.batch_size) print("Full validation stats:", "loss=", "{:.5f}".format(val_cost), "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=", "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration)) val_cost_.append(val_cost) val_f1_mic_.append(val_f1_mic) val_f1_mac_.append(val_f1_mac) duration_.append(duration) print("mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n". format(np.mean(val_cost_), np.mean(val_f1_mic_), np.mean(val_f1_mac_), np.mean(duration_))) # write validation results with open(log_dir(sampler_name) + "val_stats.txt", "w") as fp: for iter in range(10): fp.write( "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n". 
format(val_cost_[iter], val_f1_mic_[iter], val_f1_mac_[iter], duration_[iter])) fp.write( "mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n". format(np.mean(val_cost_), np.mean(val_f1_mic_), np.mean(val_f1_mac_), np.mean(duration_))) fp.write( "variance: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n" .format(np.var(val_cost_), np.var(val_f1_mic_), np.var(val_f1_mac_), np.var(duration_))) # test val_cost_ = [] val_f1_mic_ = [] val_f1_mac_ = [] duration_ = [] print("Writing test set stats to file (don't peak!)") # timeline if FLAGS.timeline == True: run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() else: run_options = None run_metadata = None for iter in range(10): val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate( sess, model, minibatch, FLAGS.batch_size, run_options, run_metadata, test=True) #val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size, test=True) print("Full validation stats:", "loss=", "{:.5f}".format(val_cost), "f1_micro=", "{:.5f}".format(val_f1_mic), "f1_macro=", "{:.5f}".format(val_f1_mac), "time=", "{:.5f}".format(duration)) val_cost_.append(val_cost) val_f1_mic_.append(val_f1_mic) val_f1_mac_.append(val_f1_mac) duration_.append(duration) print("mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n". format(np.mean(val_cost_), np.mean(val_f1_mic_), np.mean(val_f1_mac_), np.mean(duration_))) # write test results with open(log_dir(sampler_name) + "test_stats.txt", "w") as fp: for iter in range(10): fp.write( "loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n". format(val_cost_[iter], val_f1_mic_[iter], val_f1_mac_[iter], duration_[iter])) fp.write( "mean: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n". format(np.mean(val_cost_), np.mean(val_f1_mic_), np.mean(val_f1_mac_), np.mean(duration_))) fp.write( "variance: loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}\n" .format(np.var(val_cost_), np.var(val_f1_mic_), np.var(val_f1_mac_), np.var(duration_))) # create timeline object, and write it to a json if FLAGS.timeline == True: tl = timeline.Timeline(run_metadata.step_stats) ctf = tl.generate_chrome_trace_format(show_memory=True) with open(log_dir(sampler_name) + 'timeline.json', 'w') as f: print('timeline written at %s' % (log_dir(sampler_name) + 'timelnie.json')) f.write(ctf) sess.close() tf.reset_default_graph()
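# ---------------------------------------------------------------------------
# Illustrative follow-up (not part of the training script): reading back the
# sparse per-node loss accumulators saved above and computing the average loss
# over adjacency cells that were visited at least once. The argument is the
# same loss_node_path used with sparse.save_npz above.
def mean_loss_per_visited_cell(loss_node_path):
    import numpy as np
    from scipy import sparse
    ln = sparse.load_npz(loss_node_path + 'loss_node.npz').tocsr()
    lnc = sparse.load_npz(loss_node_path + 'loss_node_count.npz').tocsr()
    rows, cols = lnc.nonzero()  # cells with a nonzero visit count
    if rows.size == 0:
        return 0.0
    return float(np.asarray(ln[rows, cols]).sum() / np.asarray(lnc[rows, cols]).sum())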
def train(train_data, test_data=None): G = train_data[0] features = train_data[1] id_map = train_data[2] class_map = train_data[4] target_map = train_data[5] target_scaler = train_data[6] if isinstance(list(class_map.values())[0], list): num_classes = len(list(class_map.values())[0]) else: num_classes = len(set(class_map.values())) num_targets = 1 if not features is None: # pad with dummy zero vector features = np.vstack([features, np.zeros((features.shape[1], ))]) context_pairs = train_data[3] if FLAGS.random_context else None print('classification flag:', FLAGS.classification) print('regression flag:', FLAGS.regression) if FLAGS.classification and FLAGS.regression: raise Exception( "Either classification or regression must be set to True, not both." ) elif not FLAGS.classification and not FLAGS.regression: raise Exception("Either classification or regression must be set to True.") elif FLAGS.regression: print('Regression flag is set. Overwriting classmap and num_classes') class_map = target_map num_classes = num_targets placeholders = construct_placeholders(num_classes) minibatch = NodeMinibatchIterator(G, id_map, placeholders, class_map, num_classes, batch_size=FLAGS.batch_size, max_degree=FLAGS.max_degree, context_pairs=context_pairs) adj_info_ph = tf.compat.v1.placeholder(tf.int32, shape=minibatch.adj.shape) adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") if FLAGS.model == 'graphsage_mean': # Create model sampler = UniformNeighborSampler(adj_info) if FLAGS.samples_3 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1), # overwrite FLAGS.dim_2 SAGEInfo("node", sampler, FLAGS.samples_3, FLAGS.dim_1) ] # overwrite FLAGS.dim_2 elif FLAGS.samples_2 != 0: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1) ] # overwrite FLAGS.dim_2 else: layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1) ] model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos, model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, mse_loss=FLAGS.regression, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'gcn': # Create model sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, 2 * FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, 2 * FLAGS.dim_1) ] # overwrite FLAGS.dim_2 model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="gcn", model_size=FLAGS.model_size, concat=False, sigmoid_loss=FLAGS.sigmoid, mse_loss=FLAGS.regression, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_seq': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1) ] # overwrite FLAGS.dim_2 model = SupervisedGraphsage(num_classes, placeholders, features, adj_info, minibatch.deg, layer_infos=layer_infos, aggregator_type="seq", model_size=FLAGS.model_size, sigmoid_loss=FLAGS.sigmoid, mse_loss=FLAGS.regression, identity_dim=FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_maxpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [ SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1) ] # overwrite FLAGS.dim_2 model = 
    elif FLAGS.model == 'graphsage_twomaxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1)  # overwrite FLAGS.dim_2
        ]
        model = SupervisedGraphsage(num_classes, placeholders, features,
                                    adj_info, minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="twomaxpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    mse_loss=FLAGS.regression,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [
            SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
            SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_1)  # overwrite FLAGS.dim_2
        ]
        model = SupervisedGraphsage(num_classes, placeholders, features,
                                    adj_info, minibatch.deg,
                                    layer_infos=layer_infos,
                                    aggregator_type="meanpool",
                                    model_size=FLAGS.model_size,
                                    sigmoid_loss=FLAGS.sigmoid,
                                    mse_loss=FLAGS.regression,
                                    identity_dim=FLAGS.identity_dim,
                                    logging=True)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.compat.v1.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True

    # Initialize WandB experiment
    wandb.init(project='GraphSAGE_trial', save_code=True, tags=['supervised'])
    wandb.config.update(flags.FLAGS)

    # Initialize session
    sess = tf.compat.v1.Session(config=config)
    merged = tf.compat.v1.summary.merge_all()
    summary_writer = tf.compat.v1.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.compat.v1.global_variables_initializer(),
             feed_dict={adj_info_ph: minibatch.adj})

    # Init saver
    saver = tf.compat.v1.train.Saver(max_to_keep=8,
                                     keep_checkpoint_every_n_hours=1)

    # Train model
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []

    train_adj_info = tf.compat.v1.assign(adj_info, minibatch.adj)
    val_adj_info = tf.compat.v1.assign(adj_info, minibatch.test_adj)

    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()

        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds],
                            feed_dict=feed_dict)
            train_cost = outs[2]

            # Validation
            if iter % FLAGS.validate_iter == 0:
                sess.run(val_adj_info.op)
                if FLAGS.classification:
                    if FLAGS.validate_batch_size == -1:
                        val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
                            sess, model, minibatch, FLAGS.batch_size)
                    else:
                        val_cost, val_f1_mic, val_f1_mac, duration = evaluate(
                            sess, model, minibatch, FLAGS.validate_batch_size)
                elif FLAGS.regression:
                    if FLAGS.validate_batch_size == -1:
                        val_cost, val_rmse, val_mape, duration = incremental_evaluate_regress(
                            sess, model, minibatch, FLAGS.batch_size,
                            scaler=target_scaler)
                    else:
                        val_cost, val_rmse, val_mape, duration = evaluate_regress(
                            sess, model, minibatch, FLAGS.validate_batch_size,
                            scaler=target_scaler)
                else:
                    raise Exception(
                        "Either classification or regression must be set to True.")
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)

            if total_steps % FLAGS.print_every == 0:
                if FLAGS.classification:
                    train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                    print("[%03d/%03d]" % (epoch + 1, FLAGS.epochs),
                          "Iter:", '%04d' % iter,
                          "train_loss =", "{:.5f}".format(train_cost),
                          "train_f1_mic =", "{:.5f}".format(train_f1_mic),
                          "train_f1_mac =", "{:.5f}".format(train_f1_mac),
                          "val_loss =", "{:.5f}".format(val_cost),
                          "val_f1_mic =", "{:.5f}".format(val_f1_mic),
                          "val_f1_mac =", "{:.5f}".format(val_f1_mac),
                          "time =", "{:.5f}".format(avg_time))
                elif FLAGS.regression:
                    train_rmse = calc_rmse(labels, outs[-1])
                    train_mape = calc_mape(labels, outs[-1])
                    print("[%03d/%03d]" % (epoch + 1, FLAGS.epochs),
                          "Iter:", '%04d' % iter,
                          "train_loss =", "{:.5f}".format(train_cost),
                          "train_rmse =", "{:.5f}".format(train_rmse),
                          "train_mape =", "{:.5f}".format(train_mape),
                          "val_loss =", "{:.5f}".format(val_cost),
                          "val_rmse =", "{:.5f}".format(val_rmse),
                          "val_mape =", "{:.5f}".format(val_mape),
                          "time =", "{:.5f}".format(avg_time))

            # W&B Logging
            if FLAGS.wandb_log and iter % FLAGS.wandb_log_iter == 0:
                if FLAGS.classification:
                    train_f1_mic, train_f1_mac = calc_f1(labels, outs[-1])
                    wandb.log({'train_loss': train_cost, 'epoch': epoch})
                    wandb.log({'train_f1_mic': train_f1_mic, 'epoch': epoch})
                    wandb.log({'train_f1_mac': train_f1_mac, 'epoch': epoch})
                    wandb.log({'val_cost': val_cost, 'epoch': epoch})
                    wandb.log({'val_f1_mic': val_f1_mic, 'epoch': epoch})
                    wandb.log({'val_f1_mac': val_f1_mac, 'epoch': epoch})
                    wandb.log({'time': avg_time, 'epoch': epoch})
                elif FLAGS.regression:
                    train_rmse = calc_rmse(labels, outs[-1])
                    train_mape = calc_mape(labels, outs[-1])
                    wandb.log({'train_loss': train_cost, 'epoch': epoch})
                    wandb.log({'train_rmse': train_rmse, 'epoch': epoch})
                    wandb.log({'train_mape': train_mape, 'epoch': epoch})
                    wandb.log({'val_cost': val_cost, 'epoch': epoch})
                    wandb.log({'val_rmse': val_rmse, 'epoch': epoch})
                    wandb.log({'val_mape': val_mape, 'epoch': epoch})
                    wandb.log({'time': avg_time, 'epoch': epoch})

            iter += 1
            total_steps += 1

            if total_steps > FLAGS.max_total_steps:
                break

        # Save Model checkpoints
        if (FLAGS.save_checkpoints
                and epoch % FLAGS.save_checkpoints_epoch == 0 and epoch != 0):
            # saver.save(sess, log_dir() + 'model', global_step=1000)
            print('Save model checkpoint:', wandb.run.dir, iter, total_steps,
                  epoch)
            saver.save(
                sess,
                os.path.join(wandb.run.dir, "model-" + str(epoch + 1) + ".ckpt"))

        if total_steps > FLAGS.max_total_steps:
            break

    print("Optimization Finished!")
    sess.run(val_adj_info.op)

    if FLAGS.classification:
        val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
            sess, model, minibatch, FLAGS.batch_size)
        print("Full validation stats:",
              "loss=", "{:.5f}".format(val_cost),
              "f1_micro=", "{:.5f}".format(val_f1_mic),
              "f1_macro=", "{:.5f}".format(val_f1_mac),
              "time=", "{:.5f}".format(duration))
        with open(log_dir() + "val_stats.txt", "w") as fp:
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".
                     format(val_cost, val_f1_mic, val_f1_mac, duration))
        print("Writing test set stats to file (don't peek!)")
        val_cost, val_f1_mic, val_f1_mac, duration = incremental_evaluate(
            sess, model, minibatch, FLAGS.batch_size, test=True)
        with open(log_dir() + "test_stats.txt", "w") as fp:
            fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".format(
                val_cost, val_f1_mic, val_f1_mac))
    elif FLAGS.regression:
        val_cost, val_rmse, val_mape, duration = incremental_evaluate_regress(
            sess, model, minibatch, FLAGS.batch_size, scaler=target_scaler)
        print("Full validation stats:",
              "loss=", "{:.5f}".format(val_cost),
              "rmse=", "{:.5f}".format(val_rmse),
              "mape=", "{:.5f}".format(val_mape),
              "time=", "{:.5f}".format(duration))
        with open(log_dir() + "val_stats.txt", "w") as fp:
            fp.write("loss={:.5f} rmse={:.5f} time={:.5f}".format(
                val_cost, val_rmse, duration))
        print("Writing test set stats to file (don't peek!)")
        val_cost, val_rmse, val_mape, duration = incremental_evaluate_regress(
            sess, model, minibatch, FLAGS.batch_size, test=True,
            scaler=target_scaler)
        with open(log_dir() + "test_stats.txt", "w") as fp:
            fp.write("loss={:.5f} rmse={:.5f} mape={:.5f}".format(
                val_cost, val_rmse, val_mape))
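# --------------------------------------------------------------------------
# Minimal entry-point sketch (an assumption, not part of the original file):
# it presumes a load_data(FLAGS.train_prefix) helper, as in the reference
# GraphSAGE implementation, that returns the 7-tuple unpacked at the top of
# train(): (G, features, id_map, context_pairs, class_map, target_map,
# target_scaler). Both the helper name and the __main__ guard are
# illustrative assumptions.
def main(argv=None):
    print("Loading training data..")
    train_data = load_data(FLAGS.train_prefix)  # hypothetical loader
    print("Done loading training data..")
    train(train_data)


if __name__ == '__main__':
    tf.compat.v1.app.run()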