Code Example #1
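Note: the snippets on this page are shown without their import headers or shared test fixtures. The sketch below reconstructs the imports they rely on; exact module paths may vary between torch_geometric versions, `trivial_metric` (used in Examples #2 and #5) is a test-local helper that is not part of the library, and `train()` changed signature over time (the Lightning variant in Examples #2 and #6 takes `(model, datamodule, ...)`, while the older loop in Examples #4, #5, and #7 takes `(loggers, loaders, model, optimizer, scheduler)`).

import os.path as osp
import random
import shutil
import sys
from collections import namedtuple

import pytest
import torch

import torch_geometric.graphgym.register as register
from torch_geometric import seed_everything
from torch_geometric.graphgym.checkpoint import get_ckpt_dir
from torch_geometric.graphgym.config import (cfg, dump_cfg, load_cfg,
                                             set_agg_dir, set_out_dir,
                                             set_run_dir)
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import (LoggerCallback, create_logger,
                                             set_printing)
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage
from torch_geometric.graphgym.models.head import GNNNodeHead
from torch_geometric.graphgym.optim import (OptimizerConfig, SchedulerConfig,
                                            create_optimizer, create_scheduler)
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device

Example #1 trains a GraphGym model for one epoch (capped at four steps) under PyTorch Lightning, writing all outputs to a temporary directory.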
def test_train(tmpdir):
    import pytorch_lightning as pl

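    # `cfg` and `args` are not defined in this snippet; they are assumed to be
    # module-level test fixtures (see Code Example #2 for how `args` is built).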
    load_cfg(cfg, args)
    cfg.out_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join(tmpdir, 'pyg_test_datasets', 'Planetoid')

    set_out_dir(cfg.out_dir, args.cfg_file)
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir)

    loaders = create_loader()
    model = create_model()
    cfg.params = params_count(model)
    logger = LoggerCallback()
    trainer = pl.Trainer(max_epochs=1, max_steps=4, callbacks=[logger],
                         log_every_n_steps=1)
    train_loader, val_loader = loaders[0], loaders[1]
    trainer.fit(model, train_loader, val_loader)

    shutil.rmtree(cfg.out_dir)
Code Example #2
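A parametrized end-to-end test: it loads example_node.yml, optionally registers a trivial custom metric, builds a GraphGymDataModule and a model, checks the architecture, optimizer, scheduler, and parameter count, then trains through the Lightning-based train() helper and aggregates the run's results.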
# The decorators below are an assumption inferred from the three boolean
# arguments; the original test is presumably parametrized over them.
@pytest.mark.parametrize('auto_resume', [True, False])
@pytest.mark.parametrize('skip_train_eval', [True, False])
@pytest.mark.parametrize('use_trivial_metric', [True, False])
def test_run_single_graphgym(auto_resume, skip_train_eval, use_trivial_metric):
    Args = namedtuple('Args', ['cfg_file', 'opts'])
    root = osp.dirname(osp.realpath(__file__))
    args = Args(osp.join(root, 'example_node.yml'), [])

    load_cfg(cfg, args)
    cfg.out_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join('/', 'tmp', 'pyg_test_datasets', 'Planetoid')
    cfg.train.auto_resume = auto_resume

    set_out_dir(cfg.out_dir, args.cfg_file)
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir)

    cfg.train.skip_train_eval = skip_train_eval
    cfg.train.enable_ckpt = use_trivial_metric and skip_train_eval
    if use_trivial_metric:
        if 'trivial' not in register.metric_dict:
            register.register_metric('trivial', trivial_metric)
        global num_trivial_metric_calls
        num_trivial_metric_calls = 0
        cfg.metric_best = 'trivial'
        cfg.custom_metrics = ['trivial']
    else:
        cfg.metric_best = 'auto'
        cfg.custom_metrics = []

    datamodule = GraphGymDataModule()
    assert len(datamodule.loaders) == 3

    model = create_model()
    assert isinstance(model, torch.nn.Module)
    assert isinstance(model.encoder, FeatureEncoder)
    assert isinstance(model.mp, GNNStackStage)
    assert isinstance(model.post_mp, GNNNodeHead)
    assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp

    optimizer, scheduler = model.configure_optimizers()
    assert isinstance(optimizer[0], torch.optim.Adam)
    assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23883

    train(model,
          datamodule,
          logger=True,
          trainer_config={"enable_progress_bar": False})

    assert osp.isdir(get_ckpt_dir()) is cfg.train.enable_ckpt

    agg_runs(cfg.out_dir, cfg.metric_best)

    shutil.rmtree(cfg.out_dir)
Code Example #3
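Exercises the GraphGymModule API directly: the created model is a pl.LightningModule, and training_step, validation_step, and test_step are each called on a single batch to verify the returned output keys.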
def test_graphgym_module(tmpdir):
    import pytorch_lightning as pl

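    # As in Code Example #1, `cfg` and `args` are assumed to be module-level
    # test fixtures.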
    load_cfg(cfg, args)
    cfg.out_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join(tmpdir, str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join(tmpdir, 'pyg_test_datasets', 'Planetoid')

    set_out_dir(cfg.out_dir, args.cfg_file)
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir)

    loaders = create_loader()
    assert len(loaders) == 3

    model = create_model()
    assert isinstance(model, pl.LightningModule)

    optimizer, scheduler = model.configure_optimizers()
    assert isinstance(optimizer[0], torch.optim.Adam)
    assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23880

    keys = {"loss", "true", "pred_score", "step_end_time"}
    # test training step
    batch = next(iter(loaders[0]))
    batch.to(model.device)
    outputs = model.training_step(batch)
    assert keys == set(outputs.keys())
    assert isinstance(outputs["loss"], torch.Tensor)

    # test validation step
    batch = next(iter(loaders[1]))
    batch.to(model.device)
    outputs = model.validation_step(batch)
    assert keys == set(outputs.keys())
    assert isinstance(outputs["loss"], torch.Tensor)

    # test test step
    batch = next(iter(loaders[2]))
    batch.to(model.device)
    outputs = model.test_step(batch)
    assert keys == set(outputs.keys())
    assert isinstance(outputs["loss"], torch.Tensor)

    shutil.rmtree(cfg.out_dir)
Code Example #4
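An end-to-end test against the pre-Lightning training loop: the optimizer and scheduler are constructed explicitly via OptimizerConfig and SchedulerConfig before calling train(loggers, loaders, model, optimizer, scheduler).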
def test_run_single_graphgym():
    Args = namedtuple('Args', ['cfg_file', 'opts'])
    root = osp.dirname(osp.realpath(__file__))
    args = Args(osp.join(root, 'example_node.yml'), [])

    load_cfg(cfg, args)
    cfg.out_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir, args.cfg_file)

    loaders = create_loader()
    assert len(loaders) == 3

    loggers = create_logger()
    assert len(loggers) == 3

    model = create_model()
    assert isinstance(model, torch.nn.Module)
    assert isinstance(model.encoder, FeatureEncoder)
    assert isinstance(model.mp, GNNStackStage)
    assert isinstance(model.post_mp, GNNNodeHead)
    assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp

    optimizer_config = OptimizerConfig(optimizer=cfg.optim.optimizer,
                                       base_lr=cfg.optim.base_lr,
                                       weight_decay=cfg.optim.weight_decay,
                                       momentum=cfg.optim.momentum)
    optimizer = create_optimizer(model.parameters(), optimizer_config)
    assert isinstance(optimizer, torch.optim.Adam)

    scheduler_config = SchedulerConfig(scheduler=cfg.optim.scheduler,
                                       steps=cfg.optim.steps,
                                       lr_decay=cfg.optim.lr_decay,
                                       max_epoch=cfg.optim.max_epoch)
    scheduler = create_scheduler(optimizer, scheduler_config)
    assert isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23880

    train(loggers, loaders, model, optimizer, scheduler)

    agg_runs(set_agg_dir(cfg.out_dir, args.cfg_file), cfg.metric_best)

    shutil.rmtree(cfg.out_dir)
    shutil.rmtree(cfg.dataset.dir)
Code Example #5
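A variant of the previous test that additionally toggles train-time evaluation and the custom 'trivial' metric, then checks how many times the metric was invoked and whether a checkpoint directory was created.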
# As in Code Example #2, the parametrize decorators are an assumption
# inferred from the boolean arguments.
@pytest.mark.parametrize('skip_train_eval', [True, False])
@pytest.mark.parametrize('use_trivial_metric', [True, False])
def test_run_single_graphgym(skip_train_eval, use_trivial_metric):
    Args = namedtuple('Args', ['cfg_file', 'opts'])
    root = osp.dirname(osp.realpath(__file__))
    args = Args(osp.join(root, 'example_node.yml'), [])

    load_cfg(cfg, args)
    cfg.out_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.run_dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    cfg.dataset.dir = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dump_cfg(cfg)
    set_printing()

    seed_everything(cfg.seed)
    auto_select_device()
    set_run_dir(cfg.out_dir, args.cfg_file)

    cfg.train.skip_train_eval = skip_train_eval
    cfg.train.enable_ckpt = use_trivial_metric and skip_train_eval
    if use_trivial_metric:
        if 'trivial' not in register.metric_dict:
            register.register_metric('trivial', trivial_metric)
        global num_trivial_metric_calls
        num_trivial_metric_calls = 0
        cfg.metric_best = 'trivial'
        cfg.custom_metrics = ['trivial']
    else:
        cfg.metric_best = 'auto'
        cfg.custom_metrics = []

    loaders = create_loader()
    assert len(loaders) == 3

    loggers = create_logger()
    assert len(loggers) == 3

    model = create_model()
    assert isinstance(model, torch.nn.Module)
    assert isinstance(model.encoder, FeatureEncoder)
    assert isinstance(model.mp, GNNStackStage)
    assert isinstance(model.post_mp, GNNNodeHead)
    assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp

    optimizer = create_optimizer(model.parameters(), cfg.optim)
    assert isinstance(optimizer, torch.optim.Adam)

    scheduler = create_scheduler(optimizer, cfg.optim)
    assert isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)

    cfg.params = params_count(model)
    assert cfg.params == 23880

    train(loggers, loaders, model, optimizer, scheduler)

    if use_trivial_metric:
        # 6 total epochs, 4 eval epochs, 3 splits (1 training split)
        # Parenthesized so the comparison applies in both branches; without
        # the parentheses the assert trivially passes when skip_train_eval
        # is False.
        assert num_trivial_metric_calls == (12 if skip_train_eval else 14)

    assert osp.isdir(get_ckpt_dir()) is cfg.train.enable_ckpt

    agg_runs(set_agg_dir(cfg.out_dir, args.cfg_file), cfg.metric_best)

    shutil.rmtree(cfg.out_dir)
    shutil.rmtree(cfg.dataset.dir)
Code Example #6
File: main.py  Project: rusty1s/pytorch_geometric
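The Lightning-based GraphGym entry point: it parses command-line arguments, loads the config, and repeats seeding, data-module creation, and training for args.repeat runs.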
# The listing shows only the last three imports; the rest are reconstructed
# from usage and may differ slightly between PyG versions.
import logging

import torch

from torch_geometric import seed_everything
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.graphgym.config import (cfg, dump_cfg, load_cfg,
                                             set_out_dir, set_run_dir)
from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device

if __name__ == '__main__':
    # Load cmd line args
    args = parse_args()
    # Load config file
    load_cfg(cfg, args)
    set_out_dir(cfg.out_dir, args.cfg_file)
    # Set PyTorch environment
    torch.set_num_threads(cfg.num_threads)
    dump_cfg(cfg)
    # Repeat for different random seeds
    for i in range(args.repeat):
        set_run_dir(cfg.out_dir)
        set_printing()
        # Set configurations for each run
        cfg.seed = cfg.seed + 1
        seed_everything(cfg.seed)
        auto_select_device()
        # Set machine learning pipeline
        datamodule = GraphGymDataModule()
        model = create_model()
        # Print model info
        logging.info(model)
        logging.info(cfg)
        cfg.params = params_count(model)
        logging.info('Num parameters: %s', cfg.params)
        train(model, datamodule, logger=True)
Code Example #7
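The pre-Lightning GraphGym entry point: the same repeat loop, but with explicit loader, logger, optimizer, and scheduler construction. The new_optimizer_config and new_scheduler_config helpers presumably wrap cfg into the OptimizerConfig and SchedulerConfig dataclasses; they are defined in the surrounding script and not shown here.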
if __name__ == '__main__':
    # Load cmd line args
    args = parse_args()
    # Load config file
    load_cfg(cfg, args)
    # Set PyTorch environment
    torch.set_num_threads(cfg.num_threads)
    dump_cfg(cfg)
    set_printing()
    # Repeat for different random seeds
    for i in range(args.repeat):
        # Set configurations for each run
        seed_everything(cfg.seed + i)
        auto_select_device()
        set_run_dir(cfg.out_dir, args.cfg_file)
        # Set machine learning pipeline
        loaders = create_loader()
        loggers = create_logger()
        model = create_model()
        optimizer = create_optimizer(model.parameters(),
                                     new_optimizer_config(cfg))
        scheduler = create_scheduler(optimizer, new_scheduler_config(cfg))
        # Print model info
        logging.info(model)
        logging.info(cfg)
        cfg.params = params_count(model)
        logging.info('Num parameters: %s', cfg.params)
        # Start training
        if cfg.train.mode == 'standard':
            train(loggers, loaders, model, optimizer, scheduler)