def test_unsupervised_node_classification():
    """Parse ``-m prone -dt ppi`` and check the resulting model/dataset lists."""
    program = sys.argv[0]
    sys.argv = [program, "-m", "prone", "-dt", "ppi"]
    training_parser = options.get_training_parser()
    known, _unknown = training_parser.parse_known_args()
    parsed = options.parse_args_and_arch(training_parser, known)
    print(parsed)
    assert parsed.model[0] == "prone"
    assert parsed.dataset[0] == "ppi"
def test_unsupervised_graph_classification_without_task():
    """Parse ``-m infograph -dt mutag`` (no ``-t`` flag) and check model/dataset.

    Named distinctly from the ``-t unsupervised_graph_classification``
    variant of this test defined later in the module. If both used the same
    name, the later ``def`` would rebind it, so this test would be shadowed
    and pytest would never collect or run it.
    """
    sys.argv = [sys.argv[0], "-m", "infograph", "-dt", "mutag"]
    parser = options.get_training_parser()
    args, _ = parser.parse_known_args()
    args = options.parse_args_and_arch(parser, args)
    print(args)
    assert args.model[0] == "infograph"
    assert args.dataset[0] == "mutag"
def test_training_options():
    """Parse ``-m gcn -dt cora`` and check the resulting model/dataset lists."""
    program = sys.argv[0]
    sys.argv = [program, "-m", "gcn", "-dt", "cora"]
    training_parser = options.get_training_parser()
    known, _unknown = training_parser.parse_known_args()
    parsed = options.parse_args_and_arch(training_parser, known)
    print(parsed)
    assert parsed.model[0] == "gcn"
    assert parsed.dataset[0] == "cora"
def test_attributed_graph_clustering():
    """Parse ``-m daegc -dt cora``; check model, dataset and ``num_clusters == 7``."""
    program = sys.argv[0]
    sys.argv = [program, "-m", "daegc", "-dt", "cora"]
    training_parser = options.get_training_parser()
    known, _unknown = training_parser.parse_known_args()
    parsed = options.parse_args_and_arch(training_parser, known)
    print(parsed)
    assert parsed.model[0] == "daegc"
    assert parsed.dataset[0] == "cora"
    # 7 is the value the parsed options are expected to carry for daegc/cora.
    assert parsed.num_clusters == 7
def test_multiplex_link_prediction():
    """Parse ``-m gatne -dt amazon``; check model, dataset and ``eval_type``."""
    program = sys.argv[0]
    sys.argv = [program, "-m", "gatne", "-dt", "amazon"]
    training_parser = options.get_training_parser()
    known, _unknown = training_parser.parse_known_args()
    parsed = options.parse_args_and_arch(training_parser, known)
    print(parsed)
    assert parsed.model[0] == "gatne"
    assert parsed.dataset[0] == "amazon"
    assert parsed.eval_type == "all"
def test_link_prediction_task():
    """Parse ``-t link_prediction -m prone -dt ppi`` and check the parsed options.

    Checks task, model, dataset and ``evaluate_interval == 30``.

    Named distinctly from the ``--mw``/``--dw`` link-prediction test defined
    later in the module. If both used the same name, the later ``def`` would
    rebind it, so this test would be shadowed and pytest would never run it.
    """
    sys.argv = [sys.argv[0], "-t", "link_prediction", "-m", "prone", "-dt", "ppi"]
    parser = options.get_training_parser()
    args, _ = parser.parse_known_args()
    args = options.parse_args_and_arch(parser, args)
    print(args)
    assert args.task == "link_prediction"
    assert args.model[0] == "prone"
    assert args.dataset[0] == "ppi"
    assert args.evaluate_interval == 30
def test_graph_classification():
    """Parse a graph_classification / gin / mutag command line and check the options.

    Besides task, model and dataset, ``degree_feature`` must be exactly ``False``.
    """
    task = "graph_classification"
    program = sys.argv[0]
    sys.argv = [program, "-t", task, "-m", "gin", "-dt", "mutag"]
    training_parser = options.get_training_parser()
    known, _unknown = training_parser.parse_known_args()
    parsed = options.parse_args_and_arch(training_parser, known)
    print(parsed)
    assert parsed.task == task
    assert parsed.model[0] == "gin"
    assert parsed.dataset[0] == "mutag"
    assert parsed.degree_feature is False
def test_unsupervised_graph_classification():
    """Parse an unsupervised_graph_classification / infograph / mutag command line.

    Checks task, model, dataset, ``num_shuffle == 10`` and that
    ``degree_feature`` is exactly ``False``.
    """
    task = "unsupervised_graph_classification"
    program = sys.argv[0]
    sys.argv = [program, "-t", task, "-m", "infograph", "-dt", "mutag"]
    training_parser = options.get_training_parser()
    known, _unknown = training_parser.parse_known_args()
    parsed = options.parse_args_and_arch(training_parser, known)
    print(parsed)
    assert parsed.task == task
    assert parsed.model[0] == "infograph"
    assert parsed.dataset[0] == "mutag"
    assert parsed.num_shuffle == 10
    assert parsed.degree_feature is False
def test_link_prediction():
    """Parse ``-m prone -dt ppi`` plus ``--mw``/``--dw`` and check they round-trip.

    The ``--mw`` and ``--dw`` values must come back unchanged on the parsed
    namespace as ``mw`` and ``dw``.
    """
    program = sys.argv[0]
    wrapper_flags = [
        "--mw",
        "embedding_link_prediction_mw",
        "--dw",
        "embedding_link_prediction_dw",
    ]
    sys.argv = [program, "-m", "prone", "-dt", "ppi", *wrapper_flags]
    training_parser = options.get_training_parser()
    known, _unknown = training_parser.parse_known_args()
    parsed = options.parse_args_and_arch(training_parser, known)
    print(parsed)
    assert parsed.model[0] == "prone"
    assert parsed.dataset[0] == "ppi"
    assert parsed.mw == "embedding_link_prediction_mw"
    assert parsed.dw == "embedding_link_prediction_dw"
def gen_variants(**items):
    """Lazily enumerate every combination of the given keyword value lists.

    Each keyword argument maps a field name to an iterable of candidate
    values. The result is an iterator (``itertools.starmap``) over the
    Cartesian product of those iterables. Each combination is a ``Variant``
    namedtuple whose fields are the keyword names, in argument order.
    """
    Variant = namedtuple("Variant", items.keys())
    return itertools.starmap(Variant, itertools.product(*items.values()))


def getpid(_):
    """Sleep for one second, then return the PID of the calling process.

    The single argument is ignored. Presumably it exists so this can be
    mapped over a worker pool. TODO confirm against the caller.
    """
    # HACK to get different pids: the 1 s sleep keeps each call busy, so
    # (presumably) the pool hands the remaining calls to other workers.
    time.sleep(1)
    return mp.current_process().pid


if __name__ == "__main__":
    # Magic for making multiprocessing work for PyTorch: switch to the
    # "spawn" start method. force=True overrides any start method already
    # set. NOTE(review): presumably needed because CUDA state cannot be
    # shared with forked children. Confirm.
    mp.set_start_method("spawn", force=True)

    parser = options.get_training_parser()
    args, _ = parser.parse_known_args()
    args = options.parse_args_and_arch(parser, args)

    # Make sure datasets are downloaded first. build_dataset is called with
    # args.dataset temporarily set to one name at a time, and the full list
    # is restored afterwards.
    datasets = args.dataset
    for dataset in datasets:
        args.dataset = dataset
        _ = build_dataset(args)
    args.dataset = datasets

    print(args)
    # One Variant per (dataset, model, seed) combination.
    variants = list(
        gen_variants(dataset=args.dataset, model=args.model, seed=args.seed))

    device_ids = args.device_id