コード例 #1
0
ファイル: test_options.py プロジェクト: xssstory/cogdl
def test_display_options():
    sys.argv = [sys.argv[0], "-dt", "cora"]
    parser = options.get_display_data_parser()
    args = parser.parse_args()
    print(args)

    assert args.dataset[0] == "cora"
    assert args.depth > 0
コード例 #2
0
ファイル: display_data.py プロジェクト: znsoftm/cogdl
        for node, index in node_index.items():
            G.nodes[node]["color"] = cmap[index]
            G.nodes[node]["size"] = (max_index - index) * 50

        fig, ax = plt.subplots()
        plot_network(G.subgraph(list(node_set)), node_style=use_attributes())
        plt.savefig(pic_file)
        print(f"Sampled ego network saved to {pic_file} .")


if __name__ == "__main__":
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', '-s', type=int, default=0, help='random seed')
    parser.add_argument('--depth', '-d', type=int, default=3, help='neighborhood depth')
    parser.add_argument('--name', '-n', type=str, default='Cora', help='dataset name')
    parser.add_argument('--file', '-f', type=str, default='graph.jpg', help='saved file name')
    args = parser.parse_args()
    """
    parser = options.get_display_data_parser()
    args = parser.parse_args()

    if isinstance(args.seed, list):
        args.seed = args.seed[0]

    random.seed(args.seed)
    np.random.seed(args.seed)

    plot_graph(args)