def test_check_task_dataset_model_match():
    """Filtering for node_classification keeps the gcn/gat variants and drops deepwalk.

    Four variants (cora x {gcn, gat} x seeds {1, 2}) come from gen_variants and
    one extra cora/deepwalk variant is appended by hand. After
    check_task_dataset_model_match("node_classification", ...) exactly the four
    gcn/gat variants are expected to remain.
    """
    variants = list(
        gen_variants(dataset=["cora"], model=["gcn", "gat"], seed=[1, 2])
    )
    # Hand-built variant using the same field names passed to gen_variants above.
    Variant = namedtuple("Variant", ["dataset", "model", "seed"])
    variants.append(Variant(dataset="cora", model="deepwalk", seed=1))

    variants = check_task_dataset_model_match("node_classification", variants)

    assert len(variants) == 4
    # The count alone would also pass if a valid variant had been dropped
    # instead of the deepwalk one, so pin which variants survive.
    assert all(v.model != "deepwalk" for v in variants)
    assert sorted(v.model for v in variants) == ["gat", "gat", "gcn", "gcn"]
def test_gen_variants():
    """gen_variants over 1 dataset, 2 models and 2 seeds is expected to yield 4 variants."""
    produced = gen_variants(dataset=["cora"], model=["gcn", "gat"], seed=[1, 2])
    # Count by consuming the iterable rather than building an intermediate list.
    total = sum(1 for _ in produced)
    assert total == 4
mp.set_start_method("spawn", force=True) parser = options.get_training_parser() args, _ = parser.parse_known_args() args = options.parse_args_and_arch(parser, args) # Make sure datasets are downloaded first datasets = args.dataset for dataset in datasets: args.dataset = dataset _ = build_dataset(args) args.dataset = datasets print(args) variants = list( gen_variants(dataset=args.dataset, model=args.model, seed=args.seed)) device_ids = args.device_id if args.cpu: num_workers = 1 else: num_workers = len(device_ids) print("num_workers", num_workers) results_dict = defaultdict(list) with mp.Pool(processes=num_workers) as pool: # Map process to cuda device pids = pool.map(getpid, range(num_workers)) pid_to_cuda = dict(zip(pids, device_ids)) # yield all variants