예제 #1
0
파일: main.py 프로젝트: lundu28/GNE
def main():
    """Parse command-line options and run the requested pipeline stage.

    Options:
        --operation      one of ``all``, ``extract_tree``, ``train``, ``metric``
        --conf           base name of the JSON config under ``CONF_PATH``
        --metric_input   file name (under ``RES_PATH``) read by the metric step
        --train_output   suffix for the ``train_res_*`` output path
        --metric_output  suffix for the ``metric_res_*`` output path

    ``all`` runs ``train_model`` followed by ``metric``; it does not call
    ``extract_tree``.  An unrecognised operation only prints a message.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--operation', type=str, default="all",
                        help="[all | extract_tree | train | metric]")
    parser.add_argument('--conf', type=str, default="default")
    parser.add_argument('--metric_input', type=str, default="new_train_res")
    # Each default takes its own millisecond timestamp, so the two suffixes
    # are computed independently.
    parser.add_argument('--train_output', type=str,
                        default=str(int(time.time() * 1000.0)))
    parser.add_argument('--metric_output', type=str,
                        default=str(int(time.time() * 1000.0)))

    args = parser.parse_args()
    params = dh.load_json_file(os.path.join(CONF_PATH, args.conf + ".json"))
    # Resolve every input/output location relative to the results directory.
    params["metric_input"] = os.path.join(RES_PATH, args.metric_input)
    params["train_output"] = os.path.join(
        RES_PATH, "train_res_" + args.train_output)
    params["metric_output"] = os.path.join(
        RES_PATH, "metric_res_" + args.metric_output)

    if args.operation == "all":
        train_model(params)
        metric(params)
    elif args.operation == "extract_tree":
        extract_tree(params)
    elif args.operation == "train":
        train_model(params)
    elif args.operation == "metric":
        metric(params)
    else:
        # Call form works on both Python 2 and Python 3; the bare print
        # statement is a SyntaxError on Python 3.
        print("Not Support!")
예제 #2
0
def loop(params, G, embeddings, weights, metric, output_path, draw):
    """Feed every stored embedding snapshot to the metric and draw callbacks.

    The JSON file named by ``params["embeddings_path"]`` (relative to
    ``RES_PATH``) is loaded and iterated; each item's ``"embeddings"`` entry
    is converted to a NumPy array and passed first to ``metric`` and then to
    ``draw``.

    The incoming ``G``, ``embeddings``, ``weights`` and ``output_path``
    arguments are not read by this implementation.
    """
    snapshots = dh.load_json_file(
        os.path.join(RES_PATH, params["embeddings_path"]))
    for snapshot in snapshots:
        current = np.array(snapshot["embeddings"])
        metric(current)
        draw(current)
예제 #3
0
def main():
    """Parse command-line options and run training and/or evaluation.

    Options:
        --operation  one of ``all``, ``train``, ``metric``, ``draw``
        --conf       base name of the JSON config under ``SINGLE_CONF_PATH``
        --iteration  stored in ``params["iteration"]`` (default 10001)
        --model      stored in ``params["model"]`` (default ``model_simple``)

    ``all`` runs ``train_model`` followed by ``metric``.  ``draw`` is
    accepted but currently does nothing.  An unrecognised operation only
    prints a message.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--operation',
                        type=str,
                        default="all",
                        help="[all | train | metric | draw]")
    parser.add_argument('--conf', type=str, default="default")
    parser.add_argument('--iteration', type=int, default=10001)
    parser.add_argument('--model', type=str, default="model_simple")
    args = parser.parse_args()

    params = dh.load_json_file(
        os.path.join(SINGLE_CONF_PATH, args.conf + ".json"))
    # Command-line values override whatever the config file carried.
    params["iteration"] = args.iteration
    params["model"] = args.model

    if args.operation == "all":
        train_model(params)
        metric(params)
    elif args.operation == "train":
        train_model(params)
    elif args.operation == "metric":
        metric(params)
    elif args.operation == "draw":
        pass  # recognised, but no drawing step is implemented here
    else:
        # Call form works on both Python 2 and Python 3; the bare print
        # statement is a SyntaxError on Python 3.
        print("Not Support!")
예제 #4
0
def main_old():
    """Run the init / dynamic-loop pipeline selected on the command line.

    Options:
        --operation  ``all`` or ``init`` run work; ``draw`` is accepted but
                     does nothing.  The help text also lists ``train`` and
                     ``metric``, but those fall through to "Not Support!".
        --conf       base name of the JSON config under ``CONF_PATH``; also
                     used as the results sub-directory under ``RES_PATH``.

    Metric results are appended to ``<output_path>_metric`` and printed;
    drawers (if configured) receive ``<output_path>_draw`` and a running
    call counter.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--operation',
                        type=str,
                        default="all",
                        help="[all | init | train | metric | draw]")
    parser.add_argument('--conf', type=str, default="default")
    args = parser.parse_args()
    params = dh.load_json_file(os.path.join(CONF_PATH, args.conf + ".json"))

    metric_path_pre = os.path.join(RES_PATH, args.conf)
    if not os.path.exists(metric_path_pre):
        os.mkdir(metric_path_pre)
    output_path = os.path.join(metric_path_pre, dh.get_time_str())
    metric_path = output_path + "_metric"

    def metric(embeddings):
        """Evaluate every configured metric on ``embeddings``."""
        if "metrics" not in params:
            return
        # ``spec`` (not ``metric``) so the loop variable does not shadow
        # this closure's own name.
        for spec in params["metrics"]:
            res = getattr(Metric, spec["func"])(embeddings, spec)
            dh.append_to_file(metric_path, str(res) + "\n")
            print(res)

    dh.symlink(metric_path, os.path.join(metric_path_pre, "new_metric"))

    # Bound on every path; only used by draw() when drawers are configured.
    draw_path = None
    if "drawers" in params:
        draw_path = output_path + "_draw"
        if not os.path.exists(draw_path):
            os.mkdir(draw_path)
    # One-element list so the nested draw() can update the counter in place.
    draw_cnt = [0]

    def draw(embeddings):
        """Invoke every configured drawer, then advance the call counter."""
        if "drawers" not in params:
            return
        for drawer in params['drawers']:
            getattr(Metric, drawer["func"])(embeddings, drawer, draw_path,
                                            draw_cnt[0])
        draw_cnt[0] += 1

    def _run_init():
        """Import ``init.<func>`` named in the config and call its init()."""
        init_module = __import__("init." + params["init"]["func"],
                                 fromlist=["init"])
        return init_module.init(params["init"], metric, output_path, draw)

    if args.operation == "all":
        G, embeddings, weights = _run_init()
        loop_module = __import__(
            "dynamic_loop." + params["main_loop"]["func"],
            fromlist=["dynamic_loop"])
        loop_module.loop(params["main_loop"], G, embeddings, weights, metric,
                         output_path, draw)
    elif args.operation == "init":
        G, embeddings, weights = _run_init()
    elif args.operation == "draw":
        pass  # recognised, but no standalone drawing step is implemented
    else:
        # Call form works on both Python 2 and Python 3; the bare print
        # statement is a SyntaxError on Python 3.
        print("Not Support!")
예제 #5
0
파일: multi_test.py 프로젝트: luke28/NENIF
def main():
    """Run ``dfs`` once per configured model and publish the result file.

    The JSON config ``<MULTI_CONF_PATH>/<conf>.json`` is loaded.  Every
    top-level key except ``"models"`` seeds a shared ``single_params`` dict;
    each model entry then overlays its own keys (except ``"traversal"``) on
    that same dict before ``dfs`` is called with the model's traversal items
    and an append-mode handle on the timestamped output file.

    Finally ``<RES_PATH>/MultiRes`` is (re)pointed at the output file.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--conf', type=str, default="default")
    args = parser.parse_args()

    conf_file = os.path.join(MULTI_CONF_PATH, args.conf + ".json")
    params = dh.load_json_file(conf_file)

    stamp = str(int(time.time() * 1000.0))
    out_path = os.path.join(RES_PATH, "multi_res_" + stamp)

    # Settings shared by all models; per-model keys are layered on top and
    # persist into later iterations because the dict is reused.
    single_params = {
        key: value for key, value in params.items() if key != "models"
    }
    for model in params["models"]:
        single_params.update(
            (key, value) for key, value in model.items()
            if key != "traversal")
        if "traversal" in model:
            traversal = list(model["traversal"].items())
        else:
            traversal = []
        with open(out_path, "a") as f:
            dfs(traversal, single_params, f)

    link_path = os.path.join(RES_PATH, "MultiRes")
    try:
        os.symlink(out_path, link_path)
    except OSError:
        # Remove whatever is at the link path and try once more.
        os.remove(link_path)
        os.symlink(out_path, link_path)
예제 #6
0
def init(params, metric, output_path, draw):
    """Load stored embeddings and hand them to the metric and draw callbacks.

    Reads the JSON file named by ``params["embeddings_path"]`` (relative to
    ``RES_PATH``), converts its ``"embeddings"`` entry to a NumPy array and
    calls ``metric`` and then ``draw`` with it.  ``output_path`` is not used.

    Returns:
        A tuple of three ``None`` values.
    """
    stored = dh.load_json_file(
        os.path.join(RES_PATH, params["embeddings_path"]))
    matrix = np.array(stored["embeddings"])
    for callback in (metric, draw):
        callback(matrix)
    return (None,) * 3
def main():
    """Run the init / dynamic-loop pipeline selected on the command line.

    Options:
        --operation  ``all`` or ``init`` run work.  The help text also lists
                     ``train``, but it falls through to "Not Support!".
        --conf       base name of the JSON config under ``CONF_PATH``; also
                     used as the results sub-directory under ``RES_PATH``.

    Metric results are appended to ``<output_path>_metric`` and printed.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--operation',
                        type=str,
                        default="all",
                        help="[all | init | train]")
    parser.add_argument('--conf', type=str, default="amherst0.25")
    args = parser.parse_args()
    params = dh.load_json_file(os.path.join(CONF_PATH, args.conf + ".json"))

    # The original opened this file for writing and never used or closed the
    # handle.  Keep the create/truncate side effect, but release the handle
    # immediately instead of leaking it.
    open('my_embedding.txt', 'w').close()

    metric_path_pre = os.path.join(RES_PATH, args.conf)
    if not os.path.exists(metric_path_pre):
        os.mkdir(metric_path_pre)
    output_path = os.path.join(metric_path_pre, dh.get_time_str())
    print(output_path)
    metric_path = output_path + "_metric"

    def metric(embeddings):
        """Evaluate every configured metric on ``embeddings``."""
        if "metrics" not in params:
            return
        # ``spec`` (not ``metric``) so the loop variable does not shadow
        # this closure's own name.
        for spec in params["metrics"]:
            print("[] Start node classification...")
            res = getattr(Metric, spec["func"])(embeddings, spec)
            dh.append_to_file(metric_path, str(res) + "\n")
            print("[+] Metric: " + str(res))

    dh.symlink(metric_path, os.path.join(metric_path_pre, "new_metric"))

    def _run_init():
        """Import ``init.<func>`` named in the config and call its init()."""
        init_module = __import__("init." + params["init"]["func"],
                                 fromlist=["init"])
        return init_module.init(params["init"], metric, output_path)

    if args.operation == "all":
        G, embeddings, weights = _run_init()
        loop_module = __import__(
            "dynamic_loop." + params["main_loop"]["func"],
            fromlist=["dynamic_loop"])
        loop_module.loop(params["main_loop"], G, embeddings, weights, metric,
                         output_path)
        print(embeddings)
    elif args.operation == "init":
        G, embeddings, weights = _run_init()
    else:
        print("Not Support!")