コード例 #1
0
def main(config_file: str):
    """Compose a config from ``config_file``, train the network, and optionally test it.

    Extra positional command-line arguments are collected as config overrides
    and passed to ``compose`` on top of ``config_file``.

    Args:
        config_file: Path to the root config file. Its parent directory is
            passed to ``initialize`` as the config directory and its file name
            to ``compose`` as the config name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("overrides", nargs="*", default=[])
    args = parser.parse_args()

    parsed = Path(config_file)
    initialize(config_dir=str(parsed.parent), strict=False)
    cfg = compose(parsed.name, overrides=args.overrides)
    logger.info(f"Training with the following config:\n{cfg.pretty()}")

    # we want to pass in dictionaries as OmegaConf doesn't play nicely with
    # loggers and doesn't allow non-native types
    module = NetworkLightningModule(OmegaConf.to_container(cfg, resolve=True))
    trainer = Trainer(**OmegaConf.to_container(cfg.pl_trainer, resolve=True))
    trainer.fit(
        module,
        train_dataloader=DataLoader(
            # only the training split is loaded with transform=True and shuffled
            Dataset.from_data_dir(cfg.dataset.train.dir_path, transform=True),
            shuffle=True,
            batch_size=cfg.dataset.train.batch_size,
            num_workers=cfg.dataset.train.num_workers,
        ),
        val_dataloaders=DataLoader(
            Dataset.from_data_dir(cfg.dataset.val.dir_path),
            batch_size=cfg.dataset.val.batch_size,
            num_workers=cfg.dataset.val.num_workers,
        ),
    )
    if cfg.train.run_test:
        # Every test-loader setting comes from cfg.dataset.test, mirroring how
        # the train/val loaders read their own split's section.
        trainer.test(test_dataloaders=DataLoader(
            Dataset.from_data_dir(cfg.dataset.test.dir_path),
            batch_size=cfg.dataset.test.batch_size,
            num_workers=cfg.dataset.test.num_workers,
        ))
コード例 #2
0
def main(cfg: DictConfig):
    """Train the network described by ``cfg`` and optionally run the test loop.

    The LightningModule is either restored from ``cfg.train.resume_checkpoint``
    or freshly built from ``cfg.network`` plus ``cfg.train.learning_rate``.
    All dataset directories are passed through ``hydra.utils.to_absolute_path``.
    """
    logger.info(f'Training with the following config:\n{cfg.pretty()}')

    if not cfg.train.resume_checkpoint:
        net = cfg.network
        module = NetworkLightningModule(dict(
            board_size=net.board_size,
            in_channels=net.in_channels,
            residual_channels=net.residual_channels,
            residual_layers=net.residual_layers,
            learning_rate=cfg.train.learning_rate,
        ))
    else:
        module = NetworkLightningModule.load_from_checkpoint(
            hydra.utils.to_absolute_path(cfg.train.resume_checkpoint))

    gpus = cfg.train.gpus
    # DDP is only requested when more than one GPU is configured.
    if gpus is not None and gpus > 1:
        backend = 'ddp'
    else:
        backend = None
    trainer = Trainer(
        max_epochs=cfg.train.max_epochs,
        gpus=gpus,
        early_stop_callback=cfg.train.early_stop,
        distributed_backend=backend,
        train_percent_check=cfg.train.train_percent,
        val_percent_check=cfg.train.val_percent,
    )

    def build_loader(data_dir, *, train=False):
        """Wrap the dataset found in ``data_dir`` in a DataLoader."""
        path = hydra.utils.to_absolute_path(data_dir)
        if train:
            # Only the training split gets transform=True and shuffling.
            return DataLoader(
                Dataset.from_data_dir(path, transform=True),
                shuffle=True,
                batch_size=cfg.train.batch_size,
                num_workers=cfg.train.n_data_workers,
            )
        return DataLoader(
            Dataset.from_data_dir(path),
            batch_size=cfg.train.batch_size,
            num_workers=cfg.train.n_data_workers,
        )

    trainer.fit(
        module,
        train_dataloader=build_loader(cfg.dataset.train_dir, train=True),
        val_dataloaders=build_loader(cfg.dataset.val_dir),
        test_dataloaders=build_loader(cfg.dataset.test_dir),
    )
    if cfg.train.run_test:
        trainer.test()
コード例 #3
0
def test_train(tmp_path):
    """Smoke-test a fast_dev_run fit on the ``test-data`` directory.

    Trainer output is directed to pytest's ``tmp_path`` via ``default_save_path``.
    """
    hparams = dict(
        board_size=19,
        in_channels=18,
        residual_channels=1,
        residual_layers=1,
        learning_rate=0.05,
    )
    module = NetworkLightningModule(hparams)
    trainer = Trainer(fast_dev_run=True, default_save_path=tmp_path)

    # Training data is loaded with transform=True; val and test share one
    # untransformed dataset instance.
    augmented = Dataset.from_data_dir('test-data', transform=True)
    plain = Dataset.from_data_dir('test-data')
    train_loader = DataLoader(augmented, batch_size=2, shuffle=True)
    val_loader = DataLoader(plain, batch_size=2)
    test_loader = DataLoader(plain, batch_size=2)

    trainer.fit(
        module,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
        test_dataloaders=test_loader,
    )
コード例 #4
0
def test_go_dataset(filenames: List[str], length: int, transform: bool):
    """Check a ``Dataset``'s length and the shape/range of every item it returns.

    Every index is checked instead of one randomly chosen index. This makes the
    test deterministic and reproducible, and avoids the ``ValueError`` that
    ``random.randrange(0, 0)`` raises when ``length`` is 0.
    """
    dataset = Dataset(filenames, transform)
    assert len(dataset) == length
    # 19 * 19 board points plus one extra move index (presumably "pass")
    valid_moves = range(19 * 19 + 1)
    for idx in range(len(dataset)):
        planes, moves, outcome = dataset[idx]
        assert planes.size() == (18, 19, 19)
        assert moves.item() in valid_moves
        assert moves.dtype == torch.int64
        assert outcome.item() in (-1, 1)
コード例 #5
0
def test_go_dataset(filenames: List[str], length: int, transform: bool):
    """Validate ``Dataset`` length and the tensors of every sample it yields."""
    dataset = Dataset(filenames, transform)
    size = len(dataset)
    assert size == length

    # Move indices cover the 19 * 19 board points plus one extra index.
    legal_moves = range(19 * 19 + 1)
    for index in range(size):
        planes, moves, outcome = dataset[index]
        assert planes.size() == (18, 19, 19)
        assert moves.item() in legal_moves
        assert moves.dtype == torch.int64
        assert outcome.item() in (-1, 1)
        assert outcome.dtype == torch.float32
コード例 #6
0
def test_train():
    """Smoke-test a fast_dev_run fit with one dataset backing all three loaders."""
    hparams = dict(
        board_size=19,
        in_channels=18,
        residual_channels=1,
        residual_layers=1,
    )
    module = NetworkLightningModule(hparams)
    trainer = Trainer(fast_dev_run=True)

    shared = Dataset.from_data_dir('test-data')
    # One DataLoader per Trainer.fit keyword, built in train/val/test order.
    loaders = {
        keyword: DataLoader(shared, batch_size=2)
        for keyword in ('train_dataloader', 'val_dataloaders', 'test_dataloaders')
    }
    trainer.fit(module, **loaders)