Пример #1
0
def test_unsupervised_node_classification():
    """Parse ``-m prone -dt ppi`` and check model/dataset land in the lists."""
    # The training parser is driven through sys.argv; swap in a synthetic
    # command line and always restore the real one so it cannot leak into
    # tests that run afterwards.
    original_argv = sys.argv
    sys.argv = [original_argv[0], "-m", "prone", "-dt", "ppi"]
    try:
        parser = options.get_training_parser()
        args, _ = parser.parse_known_args()
        args = options.parse_args_and_arch(parser, args)
    finally:
        sys.argv = original_argv
    print(args)

    # model and dataset are list-valued; the single CLI value is element 0.
    assert args.model[0] == "prone"
    assert args.dataset[0] == "ppi"
Пример #2
0
def test_unsupervised_graph_classification():
    """Parse ``-m infograph -dt mutag`` and check model/dataset values."""
    # The training parser is driven through sys.argv; swap in a synthetic
    # command line and always restore the real one so it cannot leak into
    # tests that run afterwards.
    original_argv = sys.argv
    sys.argv = [original_argv[0], "-m", "infograph", "-dt", "mutag"]
    try:
        parser = options.get_training_parser()
        args, _ = parser.parse_known_args()
        args = options.parse_args_and_arch(parser, args)
    finally:
        sys.argv = original_argv
    print(args)

    # model and dataset are list-valued; the single CLI value is element 0.
    assert args.model[0] == "infograph"
    assert args.dataset[0] == "mutag"
Пример #3
0
def test_training_options():
    """Parse ``-m gcn -dt cora`` and check model/dataset values."""
    # The training parser is driven through sys.argv; swap in a synthetic
    # command line and always restore the real one so it cannot leak into
    # tests that run afterwards.
    original_argv = sys.argv
    sys.argv = [original_argv[0], "-m", "gcn", "-dt", "cora"]
    try:
        parser = options.get_training_parser()
        args, _ = parser.parse_known_args()
        args = options.parse_args_and_arch(parser, args)
    finally:
        sys.argv = original_argv
    print(args)

    # model and dataset are list-valued; the single CLI value is element 0.
    assert args.model[0] == "gcn"
    assert args.dataset[0] == "cora"
Пример #4
0
def test_attributed_graph_clustering():
    """Parse ``-m daegc -dt cora`` and check model, dataset and num_clusters.

    ``num_clusters`` is not passed on the command line, so the test pins the
    value the parser resolves for it (7).
    """
    # The training parser is driven through sys.argv; swap in a synthetic
    # command line and always restore the real one so it cannot leak into
    # tests that run afterwards.
    original_argv = sys.argv
    sys.argv = [original_argv[0], "-m", "daegc", "-dt", "cora"]
    try:
        parser = options.get_training_parser()
        args, _ = parser.parse_known_args()
        args = options.parse_args_and_arch(parser, args)
    finally:
        sys.argv = original_argv
    print(args)

    # model and dataset are list-valued; the single CLI value is element 0.
    assert args.model[0] == "daegc"
    assert args.dataset[0] == "cora"
    assert args.num_clusters == 7
Пример #5
0
def test_multiplex_link_prediction():
    """Parse ``-m gatne -dt amazon`` and check model, dataset and eval_type.

    ``eval_type`` is not passed on the command line, so the test pins the
    value the parser resolves for it (``"all"``).
    """
    # The training parser is driven through sys.argv; swap in a synthetic
    # command line and always restore the real one so it cannot leak into
    # tests that run afterwards.
    original_argv = sys.argv
    sys.argv = [original_argv[0], "-m", "gatne", "-dt", "amazon"]
    try:
        parser = options.get_training_parser()
        args, _ = parser.parse_known_args()
        args = options.parse_args_and_arch(parser, args)
    finally:
        sys.argv = original_argv
    print(args)

    # model and dataset are list-valued; the single CLI value is element 0.
    assert args.model[0] == "gatne"
    assert args.dataset[0] == "amazon"
    assert args.eval_type == "all"
Пример #6
0
def test_link_prediction():
    """Parse a link_prediction command line and check the resolved options.

    ``evaluate_interval`` is not passed on the command line, so the test pins
    the value the parser resolves for it (30).
    """
    # The training parser is driven through sys.argv; swap in a synthetic
    # command line and always restore the real one so it cannot leak into
    # tests that run afterwards.
    original_argv = sys.argv
    sys.argv = [
        original_argv[0], "-t", "link_prediction", "-m", "prone", "-dt", "ppi"
    ]
    try:
        parser = options.get_training_parser()
        args, _ = parser.parse_known_args()
        args = options.parse_args_and_arch(parser, args)
    finally:
        sys.argv = original_argv
    print(args)

    assert args.task == "link_prediction"
    # model and dataset are list-valued; the single CLI value is element 0.
    assert args.model[0] == "prone"
    assert args.dataset[0] == "ppi"
    assert args.evaluate_interval == 30
Пример #7
0
def test_graph_classification():
    """Parse a graph_classification command line and check resolved options.

    ``degree_feature`` is not passed on the command line, so the test pins
    the value the parser resolves for it (``False``).
    """
    # The training parser is driven through sys.argv; swap in a synthetic
    # command line and always restore the real one so it cannot leak into
    # tests that run afterwards.
    original_argv = sys.argv
    sys.argv = [
        original_argv[0], "-t", "graph_classification", "-m", "gin",
        "-dt", "mutag",
    ]
    try:
        parser = options.get_training_parser()
        args, _ = parser.parse_known_args()
        args = options.parse_args_and_arch(parser, args)
    finally:
        sys.argv = original_argv
    print(args)

    assert args.task == "graph_classification"
    # model and dataset are list-valued; the single CLI value is element 0.
    assert args.model[0] == "gin"
    assert args.dataset[0] == "mutag"
    assert args.degree_feature is False
Пример #8
0
def test_unsupervised_graph_classification():
    """Parse an unsupervised_graph_classification command line and check it.

    ``num_shuffle`` and ``degree_feature`` are not passed on the command
    line, so the test pins the values the parser resolves for them.
    """
    # The training parser is driven through sys.argv; swap in a synthetic
    # command line and always restore the real one so it cannot leak into
    # tests that run afterwards.
    original_argv = sys.argv
    sys.argv = [
        original_argv[0], "-t", "unsupervised_graph_classification",
        "-m", "infograph", "-dt", "mutag",
    ]
    try:
        parser = options.get_training_parser()
        args, _ = parser.parse_known_args()
        args = options.parse_args_and_arch(parser, args)
    finally:
        sys.argv = original_argv
    print(args)

    assert args.task == "unsupervised_graph_classification"
    # model and dataset are list-valued; the single CLI value is element 0.
    assert args.model[0] == "infograph"
    assert args.dataset[0] == "mutag"
    assert args.num_shuffle == 10
    assert args.degree_feature is False
Пример #9
0
def test_link_prediction():
    """Parse ``-m prone -dt ppi`` plus ``--mw``/``--dw`` and check the values."""
    # The training parser is driven through sys.argv; swap in a synthetic
    # command line and always restore the real one so it cannot leak into
    # tests that run afterwards.
    original_argv = sys.argv
    sys.argv = [
        original_argv[0], "-m", "prone", "-dt", "ppi",
        "--mw", "embedding_link_prediction_mw",
        "--dw", "embedding_link_prediction_dw",
    ]
    try:
        parser = options.get_training_parser()
        args, _ = parser.parse_known_args()
        args = options.parse_args_and_arch(parser, args)
    finally:
        sys.argv = original_argv
    print(args)

    # model and dataset are list-valued; the single CLI value is element 0.
    assert args.model[0] == "prone"
    assert args.dataset[0] == "ppi"
    assert args.mw == "embedding_link_prediction_mw"
    assert args.dw == "embedding_link_prediction_dw"
Пример #10
0
def gen_variants(**items):
    """Lazily yield every combination of the given keyword value lists.

    Each keyword becomes a field of a ``Variant`` namedtuple (in keyword
    order), and one ``Variant`` is produced per element of the Cartesian
    product of the keyword values.
    """
    Variant = namedtuple("Variant", items.keys())
    combos = itertools.product(*items.values())
    return (Variant(*combo) for combo in combos)


def getpid(_):
    """Return the PID of the process that executes this call.

    The single argument is ignored; presumably it exists so the function can
    be mapped over an iterable of tasks (e.g. a pool's ``map``) — confirm
    against callers.
    """
    # HACK: hold this worker busy for a second so the remaining calls are
    # likely picked up by other processes, which makes the returned pids
    # differ.
    time.sleep(1)
    current = mp.current_process()
    return current.pid


if __name__ == "__main__":
    # Magic for making multiprocessing work for PyTorch
    mp.set_start_method("spawn", force=True)

    parser = options.get_training_parser()
    args, _ = parser.parse_known_args()
    args = options.parse_args_and_arch(parser, args)

    # Make sure datasets are downloaded first
    datasets = args.dataset
    for dataset in datasets:
        args.dataset = dataset
        _ = build_dataset(args)
    args.dataset = datasets

    print(args)
    variants = list(
        gen_variants(dataset=args.dataset, model=args.model, seed=args.seed))

    device_ids = args.device_id