# NOTE: import sketch. The three examples below come from different
# torch_geometric releases (Example 1 uses the Lightning-based train /
# GraphGymDataModule, Examples 2 and 3 an older functional API), so the module
# paths below are assumptions and not every name is importable from a single
# version.
import os.path as osp
import random
import shutil
import sys
from collections import namedtuple

import pytest
import torch

import torch_geometric.graphgym.register as register
from torch_geometric import seed_everything
from torch_geometric.graphgym.checkpoint import get_ckpt_dir
from torch_geometric.graphgym.config import (cfg, dump_cfg, load_cfg,
                                             set_out_dir, set_run_dir)
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import create_logger, set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage
from torch_geometric.graphgym.models.head import GNNNodeHead
from torch_geometric.graphgym.optim import (OptimizerConfig, SchedulerConfig,
                                            create_optimizer, create_scheduler)
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils import (agg_runs, auto_select_device,
                                            params_count)
# set_agg_dir (used in Examples 2 and 3) is assumed to live in
# torch_geometric.graphgym.config in older releases.

# Minimal sketch of the custom metric referenced below: it simply counts how
# often evaluation invokes it, so the tests can assert on the call count.
num_trivial_metric_calls = 0


def trivial_metric(true, pred, task_type):
    global num_trivial_metric_calls
    num_trivial_metric_calls += 1
    return 1


# Assumed pytest parametrization over the three boolean flags.
@pytest.mark.parametrize('auto_resume', [True, False])
@pytest.mark.parametrize('skip_train_eval', [True, False])
@pytest.mark.parametrize('use_trivial_metric', [True, False])
def test_run_single_graphgym(auto_resume, skip_train_eval, use_trivial_metric):
    Args = namedtuple('Args', ['cfg_file', 'opts'])
    root = osp.join(osp.dirname(osp.realpath(__file__)))
    args = Args(osp.join(root, 'example_node.yml'), [])

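    # Load the YAML config, redirect the output and run directories to fresh
    # temporary paths, and reuse a shared on-disk cache for the Planetoid
    # dataset, so repeated test runs do not collide.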
    load_cfg(cfg, args)
    cfg.out_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join('/', 'tmp', 'pyg_test_datasets', 'Planetoid')
    cfg.train.auto_resume = auto_resume

    set_out_dir(cfg.out_dir, args.cfg_file)
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir)

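    # Optionally register the counting "trivial" metric so the test can check
    # how often evaluation is invoked; checkpointing is enabled only when both
    # flags are set.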
    cfg.train.skip_train_eval = skip_train_eval
    cfg.train.enable_ckpt = use_trivial_metric and skip_train_eval
    if use_trivial_metric:
        if 'trivial' not in register.metric_dict:
            register.register_metric('trivial', trivial_metric)
        global num_trivial_metric_calls
        num_trivial_metric_calls = 0
        cfg.metric_best = 'trivial'
        cfg.custom_metrics = ['trivial']
    else:
        cfg.metric_best = 'auto'
        cfg.custom_metrics = []

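    # GraphGymDataModule builds the loaders from the global cfg; the example
    # expects one loader each for the train, validation and test splits.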
    datamodule = GraphGymDataModule()
    assert len(datamodule.loaders) == 3

    model = create_model()
    assert isinstance(model, torch.nn.Module)
    assert isinstance(model.encoder, FeatureEncoder)
    assert isinstance(model.mp, GNNStackStage)
    assert isinstance(model.post_mp, GNNNodeHead)
    assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp

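    # configure_optimizers() comes from the LightningModule interface; here it
    # is assumed to return a list of optimizers and a list of schedulers,
    # hence the [0] indexing below.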
    optimizer, scheduler = model.configure_optimizers()
    assert isinstance(optimizer[0], torch.optim.Adam)
    assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23883

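    # The Lightning-based GraphGym train() takes the model and datamodule;
    # trainer_config is presumably forwarded to the underlying Trainer, so the
    # progress bar is disabled to keep test output quiet.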
    train(model,
          datamodule,
          logger=True,
          trainer_config={"enable_progress_bar": False})

    assert osp.isdir(get_ckpt_dir()) is cfg.train.enable_ckpt

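    # Aggregate results across runs under out_dir using the configured metric,
    # then clean up the temporary directory.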
    agg_runs(cfg.out_dir, cfg.metric_best)

    shutil.rmtree(cfg.out_dir)


# Example 2
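# This example uses the older functional GraphGym API: loaders, loggers,
# optimizer and scheduler are created explicitly and passed to a non-Lightning
# train loop.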
def test_run_single_graphgym():
    Args = namedtuple('Args', ['cfg_file', 'opts'])
    root = osp.join(osp.dirname(osp.realpath(__file__)))
    args = Args(osp.join(root, 'example_node.yml'), [])

    load_cfg(cfg, args)
    cfg.out_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir, args.cfg_file)

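    # With the functional API the loaders and per-split loggers are created
    # separately; both are expected to cover train, validation and test.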
    loaders = create_loader()
    assert len(loaders) == 3

    loggers = create_logger()
    assert len(loggers) == 3

    model = create_model()
    assert isinstance(model, torch.nn.Module)
    assert isinstance(model.encoder, FeatureEncoder)
    assert isinstance(model.mp, GNNStackStage)
    assert isinstance(model.post_mp, GNNNodeHead)
    assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp

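    # OptimizerConfig and SchedulerConfig bundle the relevant cfg.optim fields
    # before being handed to create_optimizer / create_scheduler.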
    optimizer_config = OptimizerConfig(optimizer=cfg.optim.optimizer,
                                       base_lr=cfg.optim.base_lr,
                                       weight_decay=cfg.optim.weight_decay,
                                       momentum=cfg.optim.momentum)
    optimizer = create_optimizer(model.parameters(), optimizer_config)
    assert isinstance(optimizer, torch.optim.Adam)

    scheduler_config = SchedulerConfig(scheduler=cfg.optim.scheduler,
                                       steps=cfg.optim.steps,
                                       lr_decay=cfg.optim.lr_decay,
                                       max_epoch=cfg.optim.max_epoch)
    scheduler = create_scheduler(optimizer, scheduler_config)
    assert isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23880

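    # The functional train loop consumes the loggers, loaders, optimizer and
    # scheduler directly instead of a LightningDataModule.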
    train(loggers, loaders, model, optimizer, scheduler)

    agg_runs(set_agg_dir(cfg.out_dir, args.cfg_file), cfg.metric_best)

    shutil.rmtree(cfg.out_dir)
    shutil.rmtree(cfg.dataset.dir)


# Example 3
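# This example mirrors Example 2 but is parametrized over the evaluation flags
# and passes cfg.optim directly to create_optimizer / create_scheduler, a
# different signature from the OptimizerConfig / SchedulerConfig dataclasses
# used above.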
# Assumed pytest parametrization over the two boolean flags.
@pytest.mark.parametrize('skip_train_eval', [True, False])
@pytest.mark.parametrize('use_trivial_metric', [True, False])
def test_run_single_graphgym(skip_train_eval, use_trivial_metric):
    Args = namedtuple('Args', ['cfg_file', 'opts'])
    root = osp.join(osp.dirname(osp.realpath(__file__)))
    args = Args(osp.join(root, 'example_node.yml'), [])

    load_cfg(cfg, args)
    cfg.out_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir, args.cfg_file)

    cfg.train.skip_train_eval = skip_train_eval
    cfg.train.enable_ckpt = use_trivial_metric and skip_train_eval
    if use_trivial_metric:
        if 'trivial' not in register.metric_dict:
            register.register_metric('trivial', trivial_metric)
        global num_trivial_metric_calls
        num_trivial_metric_calls = 0
        cfg.metric_best = 'trivial'
        cfg.custom_metrics = ['trivial']
    else:
        cfg.metric_best = 'auto'
        cfg.custom_metrics = []

    loaders = create_loader()
    assert len(loaders) == 3

    loggers = create_logger()
    assert len(loggers) == 3

    model = create_model()
    assert isinstance(model, torch.nn.Module)
    assert isinstance(model.encoder, FeatureEncoder)
    assert isinstance(model.mp, GNNStackStage)
    assert isinstance(model.post_mp, GNNNodeHead)
    assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp

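    # Here the cfg.optim node is passed directly; this assumes a
    # create_optimizer / create_scheduler signature that accepts the config
    # object itself rather than the explicit dataclasses of Example 2.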
    optimizer = create_optimizer(model.parameters(), cfg.optim)
    assert isinstance(optimizer, torch.optim.Adam)

    scheduler = create_scheduler(optimizer, cfg.optim)
    assert isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23880

    train(loggers, loaders, model, optimizer, scheduler)

    if use_trivial_metric:
        # 6 total epochs, 4 eval epochs, 3 splits (1 training split)
        assert num_trivial_metric_calls == (12 if skip_train_eval else 14)

    assert osp.isdir(get_ckpt_dir()) is cfg.train.enable_ckpt

    agg_runs(set_agg_dir(cfg.out_dir, args.cfg_file), cfg.metric_best)

    shutil.rmtree(cfg.out_dir)
    shutil.rmtree(cfg.dataset.dir)