import os.path as osp
import random
import shutil
import sys
from collections import namedtuple

import pytest
import torch

import torch_geometric.graphgym.register as register
from torch_geometric import seed_everything
from torch_geometric.graphgym.checkpoint import get_ckpt_dir
from torch_geometric.graphgym.config import (cfg, dump_cfg, load_cfg,
                                             set_agg_dir, set_run_dir)
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import create_logger, set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage
from torch_geometric.graphgym.models.head import GNNNodeHead
from torch_geometric.graphgym.optim import create_optimizer, create_scheduler
from torch_geometric.graphgym.train import train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device


def test_run_single_graphgym():
    # Build a minimal argument object mimicking the GraphGym CLI.
    Args = namedtuple('Args', ['cfg_file', 'opts'])
    root = osp.dirname(osp.realpath(__file__))
    args = Args(osp.join(root, 'example_node.yml'), [])

    load_cfg(cfg, args)
    # Write all outputs into unique temporary directories.
    cfg.out_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir, args.cfg_file)

    # One loader and one logger per split (train/val/test).
    loaders = create_loader()
    assert len(loaders) == 3

    loggers = create_logger()
    assert len(loggers) == 3

    model = create_model()
    assert isinstance(model, torch.nn.Module)
    assert isinstance(model.encoder, FeatureEncoder)
    assert isinstance(model.mp, GNNStackStage)
    assert isinstance(model.post_mp, GNNNodeHead)
    assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp

    optimizer = create_optimizer(model.parameters(), cfg.optim)
    assert isinstance(optimizer, torch.optim.Adam)

    scheduler = create_scheduler(optimizer, cfg.optim)
    assert isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23880

    train(loggers, loaders, model, optimizer, scheduler)

    agg_runs(set_agg_dir(cfg.out_dir, args.cfg_file), cfg.metric_best)

    shutil.rmtree(cfg.out_dir)
    shutil.rmtree(cfg.dataset.dir)
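# For reference, `example_node.yml` above is a standard GraphGym YAML config.
# A rough sketch of the sections this test exercises is given below; the
# values are illustrative assumptions, not the file shipped with the test
# suite (whose exact settings determine the 23880-parameter assertion):
#
#   dataset:
#     format: PyG
#     name: Cora
#     task: node
#     task_type: classification
#   gnn:
#     layers_pre_mp: 1
#     layers_mp: 2
#     layers_post_mp: 1
#     dim_inner: 16
#   optim:
#     optimizer: adam
#     base_lr: 0.01
#     max_epoch: 6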
@pytest.mark.parametrize('skip_train_eval', [True, False])
@pytest.mark.parametrize('use_trivial_metric', [True, False])
def test_run_single_graphgym(skip_train_eval, use_trivial_metric):
    Args = namedtuple('Args', ['cfg_file', 'opts'])
    root = osp.dirname(osp.realpath(__file__))
    args = Args(osp.join(root, 'example_node.yml'), [])

    load_cfg(cfg, args)
    cfg.out_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir, args.cfg_file)

    cfg.train.skip_train_eval = skip_train_eval
    # Only exercise checkpointing for one parameter combination.
    cfg.train.enable_ckpt = use_trivial_metric and skip_train_eval

    if use_trivial_metric:
        if 'trivial' not in register.metric_dict:
            register.register_metric('trivial', trivial_metric)
        global num_trivial_metric_calls
        num_trivial_metric_calls = 0
        cfg.metric_best = 'trivial'
        cfg.custom_metrics = ['trivial']
    else:
        cfg.metric_best = 'auto'
        cfg.custom_metrics = []

    loaders = create_loader()
    assert len(loaders) == 3

    loggers = create_logger()
    assert len(loggers) == 3

    model = create_model()
    assert isinstance(model, torch.nn.Module)
    assert isinstance(model.encoder, FeatureEncoder)
    assert isinstance(model.mp, GNNStackStage)
    assert isinstance(model.post_mp, GNNNodeHead)
    assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp

    optimizer = create_optimizer(model.parameters(), cfg.optim)
    assert isinstance(optimizer, torch.optim.Adam)

    scheduler = create_scheduler(optimizer, cfg.optim)
    assert isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23880

    train(loggers, loaders, model, optimizer, scheduler)

    if use_trivial_metric:
        # 6 total epochs, 4 eval epochs, 3 splits (1 training split):
        # with train eval skipped, the metric runs only at eval epochs,
        # 4 * 3 = 12 calls; otherwise the train split is also evaluated
        # every epoch, 6 + 4 * 2 = 14 calls. Note the parentheses: without
        # them, the conditional expression would swallow the assert.
        assert num_trivial_metric_calls == (12 if skip_train_eval else 14)

    assert osp.isdir(get_ckpt_dir()) is cfg.train.enable_ckpt

    agg_runs(set_agg_dir(cfg.out_dir, args.cfg_file), cfg.metric_best)

    shutil.rmtree(cfg.out_dir)
    shutil.rmtree(cfg.dataset.dir)
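# The parametrized test above registers a custom metric named 'trivial' and
# counts how often the GraphGym loggers invoke it. A minimal sketch of that
# module-level helper, assuming the (true, pred, task_type) call signature
# GraphGym uses for custom metrics:
num_trivial_metric_calls = 0


def trivial_metric(true, pred, task_type):
    # Count each invocation so the test can assert on evaluation frequency;
    # the returned value itself is irrelevant to the test.
    global num_trivial_metric_calls
    num_trivial_metric_calls += 1
    return 1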
import logging

import torch

from torch_geometric import seed_everything
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.graphgym.config import (cfg, dump_cfg, load_cfg,
                                             set_out_dir, set_run_dir)
from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device

if __name__ == '__main__':
    # Load cmd line args
    args = parse_args()
    # Load config file
    load_cfg(cfg, args)
    set_out_dir(cfg.out_dir, args.cfg_file)
    # Set PyTorch environment
    torch.set_num_threads(cfg.num_threads)
    dump_cfg(cfg)
    # Repeat for different random seeds
    for i in range(args.repeat):
        set_run_dir(cfg.out_dir)
        set_printing()
        # Set configurations for each run
        cfg.seed = cfg.seed + 1
        seed_everything(cfg.seed)
        auto_select_device()
        # Set machine learning pipeline
        datamodule = GraphGymDataModule()
        model = create_model()
        # Print model info
        logging.info(model)
        logging.info(cfg)
        cfg.params = params_count(model)
        logging.info('Num parameters: %s', cfg.params)
        train(model, datamodule, logger=True)
    # Aggregate results from different seeds
    agg_runs(cfg.out_dir, cfg.metric_best)
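# The tests above build the same `args` object by hand instead of calling
# `parse_args()`; the script itself is typically launched as, e.g.,
# `python main.py --cfg example_node.yml --repeat 3`. In both cases,
# `load_cfg` merges `args.opts` (yacs-style KEY VALUE pairs) on top of the
# YAML file, so individual settings can be overridden without editing it.
# A small sketch of such an override, assuming `example_node.yml` exists in
# the working directory:
from collections import namedtuple

from torch_geometric.graphgym.config import cfg, load_cfg

Args = namedtuple('Args', ['cfg_file', 'opts'])
# Override the epoch budget on top of whatever the YAML file specifies.
args = Args('example_node.yml', ['optim.max_epoch', 3])
load_cfg(cfg, args)
assert cfg.optim.max_epoch == 3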